Commit 354d4d3
Parent(s): dfcb0a5
Remove everything

Changed files:
- base_bert.py +0 -248
- bert.py +0 -225
- classifier.py +0 -411
- config.py +0 -222
- data/amazon-polarity.parquet +0 -3
- data/ids-cfimdb-dev.csv +0 -3
- data/ids-cfimdb-test-student.csv +0 -3
- data/ids-cfimdb-train.csv +0 -3
- data/ids-sst-dev.csv +0 -3
- data/ids-sst-test-student.csv +0 -3
- data/ids-sst-train.csv +0 -3
- data/nli-dev.parquet +0 -3
- data/nli-test.parquet +0 -3
- data/nli-train.parquet +0 -3
- data/stsb-dev.parquet +0 -3
- data/stsb-test.parquet +0 -3
- data/stsb-train.parquet +0 -3
- evaluation.py +0 -205
- justfile +0 -14
- optimizer.py +0 -90
- optimizer_test.npy +0 -3
- optimizer_test.py +0 -34
- predictions/README +0 -2
- prepare_submit.py +0 -18
- prompt +0 -3
- sanity_check.data +0 -0
- sanity_check.py +0 -19
- setup.sh +0 -13
- tokenizer.py +0 -0
- unsup_simcse.py +0 -252
- utils.py +0 -347
base_bert.py
DELETED
@@ -1,248 +0,0 @@
-import re
-from torch import device, dtype
-from config import BertConfig, PretrainedConfig
-from utils import *
-
-
-class BertPreTrainedModel(nn.Module):
-    config_class = BertConfig
-    base_model_prefix = "bert"
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-    _keys_to_ignore_on_load_unexpected = None
-
-    def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
-        super().__init__()
-        self.config = config
-        self.name_or_path = config.name_or_path
-
-    def init_weights(self):
-        # Initialize weights.
-        self.apply(self._init_weights)
-
-    def _init_weights(self, module):
-        """Initialize the weights."""
-        if isinstance(module, (nn.Linear, nn.Embedding)):
-            # Slightly different from the TF version, which uses truncated_normal for initialization.
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
-
-    @property
-    def dtype(self) -> dtype:
-        return get_parameter_dtype(self)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
-        config = kwargs.pop("config", None)
-        state_dict = kwargs.pop("state_dict", None)
-        cache_dir = kwargs.pop("cache_dir", None)
-        force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
-        output_loading_info = kwargs.pop("output_loading_info", False)
-        local_files_only = kwargs.pop("local_files_only", False)
-        use_auth_token = kwargs.pop("use_auth_token", None)
-        revision = kwargs.pop("revision", None)
-        mirror = kwargs.pop("mirror", None)
-
-        # Load the config if a configuration was not provided.
-        if not isinstance(config, PretrainedConfig):
-            config_path = config if config is not None else pretrained_model_name_or_path
-            config, model_kwargs = cls.config_class.from_pretrained(
-                config_path,
-                *model_args,
-                cache_dir=cache_dir,
-                return_unused_kwargs=True,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                **kwargs,
-            )
-        else:
-            model_kwargs = kwargs
-
-        # Load the model.
-        if pretrained_model_name_or_path is not None:
-            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-            if os.path.isdir(pretrained_model_name_or_path):
-                # Load from a PyTorch checkpoint.
-                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
-                archive_file = pretrained_model_name_or_path
-            else:
-                archive_file = hf_bucket_url(
-                    pretrained_model_name_or_path,
-                    filename=WEIGHTS_NAME,
-                    revision=revision,
-                    mirror=mirror,
-                )
-            try:
-                # Load from a URL, or from the cache if already cached.
-                resolved_archive_file = cached_path(
-                    archive_file,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                )
-            except EnvironmentError as err:
-                # logger.error(err)
-                msg = (
-                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
-                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
-                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}.\n\n"
-                )
-                raise EnvironmentError(msg)
-        else:
-            resolved_archive_file = None
-
-        config.name_or_path = pretrained_model_name_or_path
-
-        # Instantiate the model.
-        model = cls(config, *model_args, **model_kwargs)
-
-        if state_dict is None:
-            try:
-                state_dict = torch.load(resolved_archive_file, map_location="cpu", weights_only=True)
-            except Exception:
-                raise OSError(
-                    f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
-                    f"at '{resolved_archive_file}'"
-                )
-
-        missing_keys = []
-        unexpected_keys = []
-        error_msgs = []
-
-        # Convert an old-format PyTorch state_dict to the new format if needed.
-        old_keys = []
-        new_keys = []
-        m = {'embeddings.word_embeddings': 'word_embedding',
-             'embeddings.position_embeddings': 'pos_embedding',
-             'embeddings.token_type_embeddings': 'tk_type_embedding',
-             'embeddings.LayerNorm': 'embed_layer_norm',
-             'embeddings.dropout': 'embed_dropout',
-             'encoder.layer': 'bert_layers',
-             'pooler.dense': 'pooler_dense',
-             'pooler.activation': 'pooler_af',
-             'attention.self': 'self_attention',
-             'attention.output.dense': 'attention_dense',
-             'attention.output.LayerNorm': 'attention_layer_norm',
-             'attention.output.dropout': 'attention_dropout',
-             'intermediate.dense': 'interm_dense',
-             'intermediate.intermediate_act_fn': 'interm_af',
-             'output.dense': 'out_dense',
-             'output.LayerNorm': 'out_layer_norm',
-             'output.dropout': 'out_dropout'}
-
-        for key in state_dict.keys():
-            new_key = None
-            if "gamma" in key:
-                new_key = key.replace("gamma", "weight")
-            if "beta" in key:
-                new_key = key.replace("beta", "bias")
-            for x, y in m.items():
-                if new_key is not None:
-                    _key = new_key
-                else:
-                    _key = key
-                if x in key:
-                    new_key = _key.replace(x, y)
-            if new_key:
-                old_keys.append(key)
-                new_keys.append(new_key)
-
-        for old_key, new_key in zip(old_keys, new_keys):
-            # print(old_key, new_key)
-            state_dict[new_key] = state_dict.pop(old_key)
-
-        # Copy state_dict so _load_from_state_dict can modify it.
-        metadata = getattr(state_dict, "_metadata", None)
-        state_dict = state_dict.copy()
-        if metadata is not None:
-            state_dict._metadata = metadata
-
-        your_bert_params = [f"bert.{x[0]}" for x in model.named_parameters()]
-        for k in state_dict:
-            if k not in your_bert_params and not k.startswith("cls."):
-                possible_rename = [x for x in k.split(".")[1:-1] if x in m.values()]
-                raise ValueError(f"{k} cannot be reloaded into your model; one or more of {possible_rename} have been renamed")
-
-        # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
-        # so we need to apply the function recursively.
-        def load(module: nn.Module, prefix=""):
-            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-            module._load_from_state_dict(
-                state_dict,
-                prefix,
-                local_metadata,
-                True,
-                missing_keys,
-                unexpected_keys,
-                error_msgs,
-            )
-            for name, child in module._modules.items():
-                if child is not None:
-                    load(child, prefix + name + ".")
-
-        # Make sure we are able to load base models as well as derived models (with heads).
-        start_prefix = ""
-        model_to_load = model
-        has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
-        if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
-            start_prefix = cls.base_model_prefix + "."
-        if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
-            model_to_load = getattr(model, cls.base_model_prefix)
-        load(model_to_load, prefix=start_prefix)
-
-        if model.__class__.__name__ != model_to_load.__class__.__name__:
-            base_model_state_dict = model_to_load.state_dict().keys()
-            head_model_state_dict_without_base_prefix = [
-                key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
-            ]
-            missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
-
-        # Some models may have keys that are not in the state by design; remove them before
-        # needlessly warning the user.
-        if cls._keys_to_ignore_on_load_missing is not None:
-            for pat in cls._keys_to_ignore_on_load_missing:
-                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
-
-        if cls._keys_to_ignore_on_load_unexpected is not None:
-            for pat in cls._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-        if len(error_msgs) > 0:
-            raise RuntimeError(
-                "Error(s) in loading state_dict for {}:\n\t{}".format(
-                    model.__class__.__name__, "\n\t".join(error_msgs)
-                )
-            )
-
-        # Set the model to evaluation mode to deactivate dropout modules by default.
-        model.eval()
-
-        if output_loading_info:
-            loading_info = {
-                "missing_keys": missing_keys,
-                "unexpected_keys": unexpected_keys,
-                "error_msgs": error_msgs,
-            }
-            return model, loading_info
-
-        if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
-            import torch_xla.core.xla_model as xm
-
-            model = xm.send_cpu_data_to_device(model, xm.xla_device())
-            model.to(xm.xla_device())
-
-        return model
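
Note (not part of the diff): a minimal usage sketch of the loader deleted above. The remapping table m translates Hugging Face checkpoint parameter names (e.g. encoder.layer, attention.self) into this repo's names (bert_layers, self_attention), so a stock bert-base-uncased checkpoint loads into the custom module layout. Assuming the rest of the repo is still checked out:

    # Hypothetical usage; BertModel (in the deleted bert.py) subclasses BertPreTrainedModel.
    from bert import BertModel

    model = BertModel.from_pretrained('bert-base-uncased')
    # from_pretrained also remaps old-style keys (gamma -> weight, beta -> bias)
    # and returns the model already in eval() mode.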
bert.py
DELETED
@@ -1,225 +0,0 @@
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from base_bert import BertPreTrainedModel
-from utils import *
-
-
-class BertSelfAttention(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-
-        self.num_attention_heads = config.num_attention_heads
-        self.attention_head_size = config.hidden_size // config.num_attention_heads
-        self.all_head_size = self.num_attention_heads * self.attention_head_size
-
-        # Initialize the linear transformation layers for key, value, query.
-        self.query = nn.Linear(config.hidden_size, self.all_head_size)
-        self.key = nn.Linear(config.hidden_size, self.all_head_size)
-        self.value = nn.Linear(config.hidden_size, self.all_head_size)
-        # This dropout is applied to normalized attention scores following the original
-        # implementation of the transformer. Although it is a bit unusual, we empirically
-        # observe that it yields better performance.
-        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-
-    def transform(self, x, linear_layer):
-        # The corresponding linear_layer of k, v, q is used to project the hidden_state (x).
-        bs, seq_len = x.shape[:2]
-        proj = linear_layer(x)
-        # Next, we need to produce multiple heads for proj. This is done by splitting the
-        # hidden state into self.num_attention_heads heads, each of size self.attention_head_size.
-        proj = proj.view(bs, seq_len, self.num_attention_heads, self.attention_head_size)
-        # With a proper transpose, proj has size [bs, num_attention_heads, seq_len, attention_head_size].
-        proj = proj.transpose(1, 2)
-        return proj
-
-    def attention(self, key, query, value, attention_mask):
-        """
-        key, query, value: [batch_size, num_attention_heads, seq_len, attention_head_size]
-        attention_mask: [batch_size, 1, 1, seq_len], masks padding tokens in the input.
-        """
-
-        d_k = query.size(-1)  # attention_head_size
-        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(d_k)
-        # attention_scores shape: [batch_size, num_attention_heads, seq_len, seq_len]
-
-        # Apply the attention mask.
-        attention_scores = attention_scores + attention_mask
-
-        # Normalize scores with softmax and apply dropout.
-        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-        attention_probs = self.dropout(attention_probs)
-
-        context = torch.matmul(attention_probs, value)
-        # context shape: [batch_size, num_attention_heads, seq_len, attention_head_size]
-
-        # Concatenate all attention heads to recover the original shape: [batch_size, seq_len, hidden_size]
-        context = context.transpose(1, 2).contiguous()
-        context = context.view(context.size(0), context.size(1), -1)
-
-        return context
-
-    def forward(self, hidden_states, attention_mask):
-        """
-        hidden_states: [bs, seq_len, hidden_state]
-        attention_mask: [bs, 1, 1, seq_len]
-        output: [bs, seq_len, hidden_state]
-        """
-        # First, we have to generate the key, value, and query for each token for multi-head attention
-        # using self.transform (more details inside the function).
-        # The size of each *_layer is [bs, num_attention_heads, seq_len, attention_head_size].
-        key_layer = self.transform(hidden_states, self.key)
-        value_layer = self.transform(hidden_states, self.value)
-        query_layer = self.transform(hidden_states, self.query)
-        # Calculate the multi-head attention.
-        attn_value = self.attention(key_layer, query_layer, value_layer, attention_mask)
-        return attn_value
-
-
-class BertLayer(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        # Multi-head attention.
-        self.self_attention = BertSelfAttention(config)
-        # Add-norm for multi-head attention.
-        self.attention_dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.attention_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
-        # Feed forward.
-        self.interm_dense = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.interm_af = F.gelu
-        # Add-norm for feed forward.
-        self.out_dense = nn.Linear(config.intermediate_size, config.hidden_size)
-        self.out_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.out_dropout = nn.Dropout(config.hidden_dropout_prob)
-
-    def add_norm(self, input, output, dense_layer, dropout, ln_layer):
-        transformed_output = dense_layer(output)  # Transform the output with dense_layer.
-        transformed_output = dropout(transformed_output)  # Apply dropout.
-        added_output = input + transformed_output  # Combine the input and the output.
-        normalized_output = ln_layer(added_output)  # Apply layer normalization.
-        return normalized_output
-
-    def forward(self, hidden_states, attention_mask):
-        # 1. Multi-head attention.
-        attention_output = self.self_attention(hidden_states, attention_mask)
-
-        # 2. Add-norm after attention.
-        attention_output = self.add_norm(
-            hidden_states,
-            attention_output,
-            self.attention_dense,
-            self.attention_dropout,
-            self.attention_layer_norm
-        )
-
-        # 3. Feed-forward network.
-        intermediate_output = self.interm_af(self.interm_dense(attention_output))
-
-        # 4. Add-norm after feed-forward.
-        layer_output = self.add_norm(
-            attention_output,
-            intermediate_output,
-            self.out_dense,
-            self.out_dropout,
-            self.out_layer_norm
-        )
-
-        return layer_output
-
-
-class BertModel(BertPreTrainedModel):
-    """
-    The BERT model returns the final embeddings for each token in a sentence.
-
-    The model consists of:
-    1. Embedding layers (used in self.embed).
-    2. A stack of n BERT layers (used in self.encode).
-    3. A linear transformation layer for the [CLS] token (used in self.forward, as given).
-    """
-    def __init__(self, config):
-        super().__init__(config)
-        self.config = config
-
-        # Embedding layers.
-        self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-        self.pos_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
-        self.tk_type_embedding = nn.Embedding(config.type_vocab_size, config.hidden_size)
-        self.embed_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)
-        # Register position_ids (1, len position emb) as a buffer because it is a constant.
-        position_ids = torch.arange(config.max_position_embeddings).unsqueeze(0)
-        self.register_buffer('position_ids', position_ids)
-
-        # BERT encoder.
-        self.bert_layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
-
-        # [CLS] token transformations.
-        self.pooler_dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.pooler_af = nn.Tanh()
-
-        self.init_weights()
-
-    def embed(self, input_ids):
-        input_shape = input_ids.size()
-        seq_length = input_shape[1]
-
-        inputs_embeds = self.word_embedding(input_ids)
-
-        pos_ids = self.position_ids[:, :seq_length]
-        pos_embeds = self.pos_embedding(pos_ids)
-
-        # Since we are not considering token types, this embedding is just a placeholder.
-        tk_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
-        tk_type_embeds = self.tk_type_embedding(tk_type_ids)
-
-        embeddings = inputs_embeds + pos_embeds + tk_type_embeds
-        embeddings = self.embed_layer_norm(embeddings)
-        embeddings = self.embed_dropout(embeddings)
-
-        return embeddings
-
-    def encode(self, hidden_states, attention_mask):
-        """
-        hidden_states: the output from the embedding layer [batch_size, seq_len, hidden_size]
-        attention_mask: [batch_size, seq_len]
-        """
-        # Get the extended attention mask for self-attention.
-        # Returns extended_attention_mask of size [batch_size, 1, 1, seq_len].
-        # Distinguishes between non-padding tokens (with a value of 0) and padding tokens
-        # (with a value of a large negative number).
-        extended_attention_mask: torch.Tensor = get_extended_attention_mask(attention_mask, self.dtype)
-
-        # Pass the hidden states through the encoder layers.
-        for i, layer_module in enumerate(self.bert_layers):
-            # Feed the encoding from the last bert_layer to the next.
-            hidden_states = layer_module(hidden_states, extended_attention_mask)
-
-        return hidden_states
-
-    def forward(self, input_ids, attention_mask):
-        """
-        input_ids: [batch_size, seq_len], seq_len is the max length of the batch
-        attention_mask: same size as input_ids, 1 represents non-padding tokens, 0 represents padding tokens
-        """
-        # Get the embedding for each input token.
-        embedding_output = self.embed(input_ids=input_ids)
-
-        # Feed to a transformer (a stack of BertLayers).
-        sequence_output = self.encode(embedding_output, attention_mask=attention_mask)
-
-        # Get the [CLS] token's hidden state.
-        first_tk = sequence_output[:, 0]
-        first_tk = self.pooler_dense(first_tk)
-        first_tk = self.pooler_af(first_tk)
-
-        return {'last_hidden_state': sequence_output, 'pooler_output': first_tk}
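
Note (not part of the diff): a minimal shape-check sketch for the attention module deleted above, assuming the BertConfig defaults from the also-deleted config.py. The extended mask uses 0 for tokens to attend to and a large negative number for padding, matching the attention_scores + attention_mask step:

    import torch
    from bert import BertSelfAttention
    from config import BertConfig

    config = BertConfig()                       # hidden_size=768, num_attention_heads=12
    attn = BertSelfAttention(config)
    x = torch.randn(2, 16, config.hidden_size)  # [bs, seq_len, hidden_size]
    mask = torch.zeros(2, 1, 1, 16)             # 0 everywhere: no padding tokens
    out = attn(x, mask)
    assert out.shape == (2, 16, config.hidden_size)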
classifier.py
DELETED
@@ -1,411 +0,0 @@
-import random, numpy as np, argparse
-from types import SimpleNamespace
-import csv
-
-import torch
-from tqdm import tqdm
-import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader
-from sklearn.metrics import f1_score, accuracy_score
-
-from tokenizer import BertTokenizer
-from bert import BertModel
-from optimizer import AdamW
-
-
-TQDM_DISABLE = True
-
-
-# Fix the random seed.
-def seed_everything(seed=11711):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cudnn.deterministic = True
-
-
-class BertSentimentClassifier(torch.nn.Module):
-    '''
-    This module performs sentiment classification using BERT embeddings on the SST dataset.
-
-    In the SST dataset, there are 5 sentiment categories (from 0 - "negative" to 4 - "positive").
-    Thus, your forward() should return one logit for each of the 5 classes.
-    '''
-    def __init__(self, config, bert_model=None):
-        super(BertSentimentClassifier, self).__init__()
-        self.num_labels = config.num_labels
-        self.bert: BertModel = bert_model or BertModel.from_pretrained('bert-base-uncased')
-
-        # Pretrain mode does not require updating BERT parameters.
-        assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
-        for param in self.bert.parameters():
-            if config.fine_tune_mode == 'last-linear-layer':
-                param.requires_grad = False
-            elif config.fine_tune_mode == 'full-model':
-                param.requires_grad = True
-
-        # Create any instance variables you need to classify the sentiment of BERT embeddings.
-        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
-        self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)
-
-    def forward(self, input_ids, attention_mask):
-        '''Takes a batch of sentences and returns logits for sentiment classes'''
-        # The final BERT contextualized embedding is the hidden state of the [CLS] token (the first token).
-        # HINT: You should consider what is an appropriate return value given that
-        # the training loop currently uses F.cross_entropy as the loss function.
-
-        # Get the embedding for each input token.
-        outputs = self.bert(input_ids, attention_mask)
-        pooler_output = outputs['pooler_output']
-
-        # Pass the [CLS] token representation through the classifier.
-        logits = self.classifier(self.dropout(pooler_output))
-
-        return logits
-
-
-tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-
-class SentimentDataset(Dataset):
-    def __init__(self, dataset, args):
-        self.dataset = dataset
-        self.p = args
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def __getitem__(self, idx):
-        return self.dataset[idx]
-
-    def pad_data(self, data):
-        sents = [x[0] for x in data]
-        labels = [x[1] for x in data]
-        sent_ids = [x[2] for x in data]
-
-        encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
-        token_ids = torch.LongTensor(encoding['input_ids'])
-        attention_mask = torch.LongTensor(encoding['attention_mask'])
-        labels = torch.LongTensor(labels)
-
-        return token_ids, attention_mask, labels, sents, sent_ids
-
-    def collate_fn(self, all_data):
-        token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)
-
-        batched_data = {
-            'token_ids': token_ids,
-            'attention_mask': attention_mask,
-            'labels': labels,
-            'sents': sents,
-            'sent_ids': sent_ids
-        }
-
-        return batched_data
-
-
-class SentimentTestDataset(Dataset):
-    def __init__(self, dataset, args):
-        self.dataset = dataset
-        self.p = args
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def __getitem__(self, idx):
-        return self.dataset[idx]
-
-    def pad_data(self, data):
-        sents = [x[0] for x in data]
-        sent_ids = [x[1] for x in data]
-
-        encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
-        token_ids = torch.LongTensor(encoding['input_ids'])
-        attention_mask = torch.LongTensor(encoding['attention_mask'])
-
-        return token_ids, attention_mask, sents, sent_ids
-
-    def collate_fn(self, all_data):
-        token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)
-
-        batched_data = {
-            'token_ids': token_ids,
-            'attention_mask': attention_mask,
-            'sents': sents,
-            'sent_ids': sent_ids
-        }
-
-        return batched_data
-
-
-# Load the data: a list of (sentence, label).
-def load_data(filename, flag='train'):
-    num_labels = set()
-    data = []
-    with open(filename, 'r') as fp:
-        for record in csv.DictReader(fp, delimiter='\t'):
-            if flag == 'test':
-                sent = record['sentence'].lower().strip()
-                sent_id = record['id'].lower().strip()
-                data.append((sent, sent_id))
-            else:
-                sent = record['sentence'].lower().strip()
-                sent_id = record['id'].lower().strip()
-                label = int(record['sentiment'].strip())
-                num_labels.add(label)
-                data.append((sent, label, sent_id))
-    print(f"load {len(data)} data from {filename}")
-
-    if flag == 'train':
-        return data, len(num_labels)
-    else:
-        return data
-
-
-# Evaluate the model on dev examples.
-def model_eval(dataloader, model, device):
-    model.eval()  # Switch to eval mode; this turns off randomness like dropout.
-    y_true = []
-    y_pred = []
-    sents = []
-    sent_ids = []
-    for step, batch in enumerate(tqdm(dataloader, desc=f'eval', leave=False, disable=TQDM_DISABLE)):
-        b_ids, b_mask, b_labels, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
-                                                       batch['labels'], batch['sents'], batch['sent_ids']
-
-        b_ids = b_ids.to(device)
-        b_mask = b_mask.to(device)
-
-        logits = model(b_ids, b_mask)
-        logits = logits.detach().cpu().numpy()
-        preds = np.argmax(logits, axis=1).flatten()
-
-        b_labels = b_labels.flatten()
-        y_true.extend(b_labels)
-        y_pred.extend(preds)
-        sents.extend(b_sents)
-        sent_ids.extend(b_sent_ids)
-
-    f1 = f1_score(y_true, y_pred, average='macro')
-    acc = accuracy_score(y_true, y_pred)
-
-    return acc, f1, y_pred, y_true, sents, sent_ids
-
-
-# Evaluate the model on test examples.
-def model_test_eval(dataloader, model, device):
-    model.eval()  # Switch to eval mode; this turns off randomness like dropout.
-    y_pred = []
-    sents = []
-    sent_ids = []
-    for step, batch in enumerate(tqdm(dataloader, desc=f'eval', leave=False, disable=TQDM_DISABLE)):
-        b_ids, b_mask, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
-                                             batch['sents'], batch['sent_ids']
-
-        b_ids = b_ids.to(device)
-        b_mask = b_mask.to(device)
-
-        logits = model(b_ids, b_mask)
-        logits = logits.detach().cpu().numpy()
-        preds = np.argmax(logits, axis=1).flatten()
-
-        y_pred.extend(preds)
-        sents.extend(b_sents)
-        sent_ids.extend(b_sent_ids)
-
-    return y_pred, sents, sent_ids
-
-
-def save_model(model, optimizer, args, config, filepath):
-    save_info = {
-        'model': model.state_dict(),
-        'optim': optimizer.state_dict(),
-        'args': args,
-        'model_config': config,
-        'system_rng': random.getstate(),
-        'numpy_rng': np.random.get_state(),
-        'torch_rng': torch.random.get_rng_state(),
-    }
-
-    torch.save(save_info, filepath)
-    print(f"save the model to {filepath}")
-
-
-def train(args):
-    device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
-    # Create the data and its corresponding datasets and dataloaders.
-    train_data, num_labels = load_data(args.train, 'train')
-    dev_data = load_data(args.dev, 'valid')
-
-    train_dataset = SentimentDataset(train_data, args)
-    dev_dataset = SentimentDataset(dev_data, args)
-
-    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
-                                  num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
-    dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
-                                num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)
-
-    # Init model.
-    config = {'hidden_dropout_prob': args.hidden_dropout_prob,
-              'num_labels': num_labels,
-              'hidden_size': 768,
-              'data_dir': '.',
-              'fine_tune_mode': args.fine_tune_mode}
-
-    config = SimpleNamespace(**config)
-
-    model = BertSentimentClassifier(config)
-    model = model.to(device)
-
-    lr = args.lr
-    optimizer = AdamW(model.parameters(), lr=lr)
-    best_dev_acc = 0
-
-    # Run for the specified number of epochs.
-    for epoch in range(args.epochs):
-        model.train()
-        train_loss = 0
-        num_batches = 0
-        for batch in tqdm(train_dataloader, desc=f'train-{epoch}', leave=False, disable=TQDM_DISABLE):
-            b_ids, b_mask, b_labels = (batch['token_ids'],
-                                       batch['attention_mask'], batch['labels'])
-
-            b_ids = b_ids.to(device)
-            b_mask = b_mask.to(device)
-            b_labels = b_labels.to(device)
-
-            optimizer.zero_grad()
-            logits = model(b_ids, b_mask)
-            loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size
-
-            loss.backward()
-            optimizer.step()
-
-            train_loss += loss.item()
-            num_batches += 1
-
-        train_loss = train_loss / (num_batches)
-
-        train_acc, train_f1, *_ = model_eval(train_dataloader, model, device)
-        dev_acc, dev_f1, *_ = model_eval(dev_dataloader, model, device)
-
-        if dev_acc > best_dev_acc:
-            best_dev_acc = dev_acc
-            save_model(model, optimizer, args, config, args.filepath)
-
-        print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")
-
-
-def test(args):
-    with torch.no_grad():
-        device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
-        saved = torch.load(args.filepath, weights_only=False)
-        config = saved['model_config']
-        model = BertSentimentClassifier(config)
-        model.load_state_dict(saved['model'])
-        model = model.to(device)
-        print(f"load model from {args.filepath}")
-
-        dev_data = load_data(args.dev, 'valid')
-        dev_dataset = SentimentDataset(dev_data, args)
-        dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
-                                    num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)
-
-        dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
-        print('DONE DEV')
-        print(f"dev acc :: {dev_acc :.3f}")
-
-        # ---- SKIP RUNNING ON TEST DATASET ---- #
-        # test_data = load_data(args.test, 'test')
-        # test_dataset = SentimentTestDataset(test_data, args)
-        # test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size,
-        #                              num_workers=args.num_cpu_cores, collate_fn=test_dataset.collate_fn)
-        # test_pred, test_sents, test_sent_ids = model_test_eval(test_dataloader, model, device)
-        # print('DONE TEST')
-
-        # ---- SKIP SAVING PREDICTIONS ----
-        # with open(args.dev_out, "w+") as f:
-        #     f.write(f"id \t Predicted_Sentiment \n")
-        #     for p, s in zip(dev_sent_ids, dev_pred):
-        #         f.write(f"{p} , {s} \n")
-        # with open(args.test_out, "w+") as f:
-        #     f.write(f"id \t Predicted_Sentiment \n")
-        #     for p, s in zip(test_sent_ids, test_pred):
-        #         f.write(f"{p} , {s} \n")
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--seed", type=int, default=11711)
-    parser.add_argument("--num-cpu-cores", type=int, default=8)
-    parser.add_argument("--epochs", type=int, default=10)
-    parser.add_argument("--fine-tune-mode", type=str,
-                        help='last-linear-layer: the BERT parameters are frozen and the task-specific head parameters are updated; full-model: BERT parameters are updated as well',
-                        choices=('last-linear-layer', 'full-model'), default="last-linear-layer")
-    parser.add_argument("--use_gpu", action='store_true')
-
-    parser.add_argument("--batch_size_sst", help='64 can fit a 12GB GPU', type=int, default=64)
-    parser.add_argument("--batch_size_cfimdb", help='8 can fit a 12GB GPU', type=int, default=8)
-    parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
-    parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
-                        default=1e-3)
-
-    args = parser.parse_args()
-    return args
-
-
-def main():
-    args = get_args()
-    seed_everything(args.seed)
-    torch.set_num_threads(args.num_cpu_cores)
-
-    print('Training Sentiment Classifier on SST...')
-    config = SimpleNamespace(
-        filepath='sst-classifier.pt',
-        lr=args.lr,
-        num_cpu_cores=args.num_cpu_cores,
-        use_gpu=args.use_gpu,
-        epochs=args.epochs,
-        batch_size=args.batch_size_sst,
-        hidden_dropout_prob=args.hidden_dropout_prob,
-        train='data/ids-sst-train.csv',
-        dev='data/ids-sst-dev.csv',
-        test='data/ids-sst-test-student.csv',
-        fine_tune_mode=args.fine_tune_mode,
-        dev_out='predictions/' + args.fine_tune_mode + '-sst-dev-out.csv',
-        test_out='predictions/' + args.fine_tune_mode + '-sst-test-out.csv'
-    )
-
-    train(config)
-
-    print('Evaluating on SST...')
-    test(config)
-
-    print('Training Sentiment Classifier on cfimdb...')
-    config = SimpleNamespace(
-        filepath='cfimdb-classifier.pt',
-        lr=args.lr,
-        num_cpu_cores=args.num_cpu_cores,
-        use_gpu=args.use_gpu,
-        epochs=args.epochs,
-        batch_size=args.batch_size_cfimdb,
-        hidden_dropout_prob=args.hidden_dropout_prob,
-        train='data/ids-cfimdb-train.csv',
-        dev='data/ids-cfimdb-dev.csv',
-        test='data/ids-cfimdb-test-student.csv',
-        fine_tune_mode=args.fine_tune_mode,
-        dev_out='predictions/' + args.fine_tune_mode + '-cfimdb-dev-out.csv',
-        test_out='predictions/' + args.fine_tune_mode + '-cfimdb-test-out.csv'
-    )
-
-    train(config)
-
-    print('Evaluating on cfimdb...')
-    test(config)
-
-
-if __name__ == "__main__":
-    main()
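
Note (not part of the diff): classifier.py was an executable script; its main() built one SimpleNamespace per dataset and called train() then test(). A hypothetical programmatic equivalent for SST alone, mirroring main() above:

    from types import SimpleNamespace
    from classifier import seed_everything, train, test

    seed_everything(11711)
    config = SimpleNamespace(
        filepath='sst-classifier.pt', lr=1e-3, num_cpu_cores=8, use_gpu=False,
        epochs=10, batch_size=64, hidden_dropout_prob=0.3,
        train='data/ids-sst-train.csv', dev='data/ids-sst-dev.csv',
        test='data/ids-sst-test-student.csv', fine_tune_mode='last-linear-layer',
        dev_out='predictions/last-linear-layer-sst-dev-out.csv',
        test_out='predictions/last-linear-layer-sst-test-out.csv')
    train(config)  # saves the best-dev checkpoint to config.filepath
    test(config)   # reloads the checkpoint and reports dev accuracy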
config.py
DELETED
@@ -1,222 +0,0 @@
-from typing import Union, Tuple, Dict, Any, Optional
-import os
-import json
-from collections import OrderedDict
-import torch
-from utils import CONFIG_NAME, hf_bucket_url, cached_path, is_remote_url
-
-class PretrainedConfig(object):
-    model_type: str = ""
-    is_composition: bool = False
-
-    def __init__(self, **kwargs):
-        # Attributes with defaults.
-        self.return_dict = kwargs.pop("return_dict", True)
-        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
-        self.output_attentions = kwargs.pop("output_attentions", False)
-        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models.
-        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
-        self.pruned_heads = kwargs.pop("pruned_heads", {})
-        self.tie_word_embeddings = kwargs.pop(
-            "tie_word_embeddings", True
-        )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
-
-        # is_decoder is used in encoder-decoder models to differentiate the encoder from the decoder.
-        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
-        self.is_decoder = kwargs.pop("is_decoder", False)
-        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
-        self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
-
-        # Parameters for sequence generation.
-        self.max_length = kwargs.pop("max_length", 20)
-        self.min_length = kwargs.pop("min_length", 0)
-        self.do_sample = kwargs.pop("do_sample", False)
-        self.early_stopping = kwargs.pop("early_stopping", False)
-        self.num_beams = kwargs.pop("num_beams", 1)
-        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
-        self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
-        self.temperature = kwargs.pop("temperature", 1.0)
-        self.top_k = kwargs.pop("top_k", 50)
-        self.top_p = kwargs.pop("top_p", 1.0)
-        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
-        self.length_penalty = kwargs.pop("length_penalty", 1.0)
-        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
-        self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
-        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
-        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
-        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
-        self.output_scores = kwargs.pop("output_scores", False)
-        self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
-        self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
-        self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
-
-        # Fine-tuning task arguments.
-        self.architectures = kwargs.pop("architectures", None)
-        self.finetuning_task = kwargs.pop("finetuning_task", None)
-        self.id2label = kwargs.pop("id2label", None)
-        self.label2id = kwargs.pop("label2id", None)
-        if self.id2label is not None:
-            kwargs.pop("num_labels", None)
-            self.id2label = dict((int(key), value) for key, value in self.id2label.items())
-            # Keys are always strings in JSON, so convert ids to int here.
-        else:
-            self.num_labels = kwargs.pop("num_labels", 2)
-
-        # Tokenizer arguments.
-        self.tokenizer_class = kwargs.pop("tokenizer_class", None)
-        self.prefix = kwargs.pop("prefix", None)
-        self.bos_token_id = kwargs.pop("bos_token_id", None)
-        self.pad_token_id = kwargs.pop("pad_token_id", None)
-        self.eos_token_id = kwargs.pop("eos_token_id", None)
-        self.sep_token_id = kwargs.pop("sep_token_id", None)
-
-        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
-
-        # Task-specific arguments.
-        self.task_specific_params = kwargs.pop("task_specific_params", None)
-
-        # TPU arguments.
-        self.xla_device = kwargs.pop("xla_device", None)
-
-        # Name or path of the pretrained checkpoint.
-        self._name_or_path = str(kwargs.pop("name_or_path", ""))
-
-        # Drop the transformers version info.
-        kwargs.pop("transformers_version", None)
-
-        # Additional attributes without default values.
-        for key, value in kwargs.items():
-            try:
-                setattr(self, key, value)
-            except AttributeError as err:
-                raise err
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
-        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
-        return cls.from_dict(config_dict, **kwargs)
-
-    @classmethod
-    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
-        with open(json_file, "r", encoding="utf-8") as reader:
-            text = reader.read()
-        return json.loads(text)
-
-    @classmethod
-    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
-        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
-
-        config = cls(**config_dict)
-
-        if hasattr(config, "pruned_heads"):
-            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
-
-        # Update the config with kwargs if needed.
-        to_remove = []
-        for key, value in kwargs.items():
-            if hasattr(config, key):
-                setattr(config, key, value)
-                to_remove.append(key)
-        for key in to_remove:
-            kwargs.pop(key, None)
-
-        if return_unused_kwargs:
-            return config, kwargs
-        else:
-            return config
-
-    @classmethod
-    def get_config_dict(
-        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
-    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        cache_dir = kwargs.pop("cache_dir", None)
-        force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
-        use_auth_token = kwargs.pop("use_auth_token", None)
-        local_files_only = kwargs.pop("local_files_only", False)
-        revision = kwargs.pop("revision", None)
-
-        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-        if os.path.isdir(pretrained_model_name_or_path):
-            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
-        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
-            config_file = pretrained_model_name_or_path
-        else:
-            config_file = hf_bucket_url(
-                pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
-            )
-
-        try:
-            # Load from a URL, or from the cache if already cached.
-            resolved_config_file = cached_path(
-                config_file,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-            )
-            # Load the config dict.
-            config_dict = cls._dict_from_json_file(resolved_config_file)
-
-        except EnvironmentError as err:
-            msg = (
-                f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
-                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
-                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
-            )
-            raise EnvironmentError(msg)
-
-        except json.JSONDecodeError:
-            msg = (
-                "Couldn't reach server at '{}' to download configuration file or "
-                "configuration file is not a valid JSON file. "
-                "Please check network or file content here: {}.".format(config_file, resolved_config_file)
-            )
-            raise EnvironmentError(msg)
-
-        return config_dict, kwargs
-
-
-class BertConfig(PretrainedConfig):
-    model_type = "bert"
-
-    def __init__(
-        self,
-        vocab_size=30522,
-        hidden_size=768,
-        num_hidden_layers=12,
-        num_attention_heads=12,
-        intermediate_size=3072,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        type_vocab_size=2,
-        initializer_range=0.02,
-        layer_norm_eps=1e-12,
-        pad_token_id=0,
-        gradient_checkpointing=False,
-        position_embedding_type="absolute",
-        use_cache=True,
-        **kwargs
-    ):
-        super().__init__(pad_token_id=pad_token_id, **kwargs)
-
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.hidden_act = hidden_act
-        self.intermediate_size = intermediate_size
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.initializer_range = initializer_range
-        self.layer_norm_eps = layer_norm_eps
-        self.gradient_checkpointing = gradient_checkpointing
-        self.position_embedding_type = position_embedding_type
-        self.use_cache = use_cache
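
Note (not part of the diff): a minimal sketch of the two ways the deleted config classes were constructed:

    from config import BertConfig

    # Direct construction with the BERT-base defaults listed above.
    config = BertConfig()
    assert config.hidden_size == 768 and config.num_hidden_layers == 12

    # Overrides are plain kwargs; unknown keys become attributes on the instance.
    small = BertConfig(num_hidden_layers=4, hidden_dropout_prob=0.2)

    # Fetch config.json from the Hugging Face hub or a local directory.
    remote = BertConfig.from_pretrained('bert-base-uncased')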
data/amazon-polarity.parquet
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:dbe4770cfa6be45add6c9a322044bd4c1901520dde5a2707eca402a74fbe854e
-size 870289
data/ids-cfimdb-dev.csv
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3087f571b66860fe5d035b5a018d08202ad3fd3720e4821c04b2acf6c7ded559
-size 249095
data/ids-cfimdb-test-student.csv
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7ae611548c9eac879e9ebb406cc9f8ae68ff12f78090e4965af5cbdfa06240f4
-size 495595
data/ids-cfimdb-train.csv
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:140fc513045a966109faed46a5c7a898767b96714d71bcb9c15f659129fadcea
-size 1693182
data/ids-sst-dev.csv
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a186ce94577635fbe10beaaddd50f16cccf6c30973221cefdf90deed2a584bfe
-size 151384
data/ids-sst-test-student.csv
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:bdd5a767faa0c26782117e37767ece154c30d5d04fb8727d09c71e3850a55c7b
-size 313202
data/ids-sst-train.csv
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:03b2b625c090f94a6afd59f114cde5282e2053aab0b101e87ed695d8a0c5b1df
-size 1175139
data/nli-dev.parquet
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c267496435885e724abc71e53669fae59db875bfa13389eab8f9b0b2dfb2b32e
-size 782233
data/nli-test.parquet
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:01688df43ae4c019a86144a0d2351146b124688a55f285071cccd156225a5fdf
-size 810423
data/nli-train.parquet
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f9aeca80b1bda983ee316f854ebc37af8341877fb932dd6a2c6aba978ad112a5
-size 38396324
data/stsb-dev.parquet
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9c6e0e9881f1b398abe3e439a482f4686305c3784568c462f6bba58bdff03b0a
-size 142187
data/stsb-test.parquet
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8acbc291c50977d8655934952956016c3e049c2fe04f8a6c454c1bf6acc42ca1
-size 108100
data/stsb-train.parquet
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:eae324ff1eac2d0ba769851736eb7232eda64f370a16eb20e74a2c5f8f5fafe0
-size 470612
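
Note (not part of the diff): the data/* entries above are Git LFS pointer files, so this commit removes three-line pointers (version, oid, size) rather than the raw datasets. A standard-library sketch to check a locally checked-out file against its pointer's oid:

    import hashlib

    def lfs_oid(path, chunk_size=1 << 20):
        # sha256 hex digest, which is what Git LFS records as the file's oid.
        h = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                h.update(chunk)
        return h.hexdigest()

    assert lfs_oid('data/stsb-train.parquet') == \
        'eae324ff1eac2d0ba769851736eb7232eda64f370a16eb20e74a2c5f8f5fafe0'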
evaluation.py
DELETED
@@ -1,205 +0,0 @@
#!/usr/bin/env python3

'''
Multitask BERT evaluation functions.

When training your multitask model, you will find it useful to call
model_eval_multitask to evaluate your model on the 3 tasks' dev sets.
'''

import torch
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm
import numpy as np


TQDM_DISABLE = False


# Evaluate multitask model on SST only.
def model_eval_sst(dataloader, model, device):
    model.eval()  # Switch to eval mode, which turns off randomness like dropout.
    y_true = []
    y_pred = []
    sents = []
    sent_ids = []
    for step, batch in enumerate(tqdm(dataloader, desc='eval', disable=TQDM_DISABLE)):
        b_ids, b_mask, b_labels, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
                                                       batch['labels'], batch['sents'], batch['sent_ids']

        b_ids = b_ids.to(device)
        b_mask = b_mask.to(device)

        logits = model.predict_sentiment(b_ids, b_mask)
        logits = logits.detach().cpu().numpy()
        preds = np.argmax(logits, axis=1).flatten()

        b_labels = b_labels.flatten()
        y_true.extend(b_labels)
        y_pred.extend(preds)
        sents.extend(b_sents)
        sent_ids.extend(b_sent_ids)

    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)

    return acc, f1, y_pred, y_true, sents, sent_ids


# Evaluate multitask model on dev sets.
def model_eval_multitask(sentiment_dataloader,
                         paraphrase_dataloader,
                         sts_dataloader,
                         model, device):
    model.eval()  # Switch to eval mode, which turns off randomness like dropout.

    with torch.no_grad():
        # Evaluate sentiment classification.
        sst_y_true = []
        sst_y_pred = []
        sst_sent_ids = []
        for step, batch in enumerate(tqdm(sentiment_dataloader, desc='eval', disable=TQDM_DISABLE)):
            b_ids, b_mask, b_labels, b_sent_ids = batch['token_ids'], batch['attention_mask'], batch['labels'], batch['sent_ids']

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)

            logits = model.predict_sentiment(b_ids, b_mask)
            y_hat = logits.argmax(dim=-1).flatten().cpu().numpy()
            b_labels = b_labels.flatten().cpu().numpy()

            sst_y_pred.extend(y_hat)
            sst_y_true.extend(b_labels)
            sst_sent_ids.extend(b_sent_ids)

        sentiment_accuracy = np.mean(np.array(sst_y_pred) == np.array(sst_y_true))

        # Evaluate paraphrase detection.
        para_y_true = []
        para_y_pred = []
        para_sent_ids = []
        for step, batch in enumerate(tqdm(paraphrase_dataloader, desc='eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_labels, b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                                      batch['token_ids_2'], batch['attention_mask_2'],
                                      batch['labels'], batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            logits = model.predict_paraphrase(b_ids1, b_mask1, b_ids2, b_mask2)
            y_hat = logits.sigmoid().round().flatten().cpu().numpy()
            b_labels = b_labels.flatten().cpu().numpy()

            para_y_pred.extend(y_hat)
            para_y_true.extend(b_labels)
            para_sent_ids.extend(b_sent_ids)

        paraphrase_accuracy = np.mean(np.array(para_y_pred) == np.array(para_y_true))

        # Evaluate semantic textual similarity.
        sts_y_true = []
        sts_y_pred = []
        sts_sent_ids = []
        for step, batch in enumerate(tqdm(sts_dataloader, desc='eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_labels, b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                                      batch['token_ids_2'], batch['attention_mask_2'],
                                      batch['labels'], batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            logits = model.predict_similarity(b_ids1, b_mask1, b_ids2, b_mask2)
            y_hat = logits.flatten().cpu().numpy()
            b_labels = b_labels.flatten().cpu().numpy()

            sts_y_pred.extend(y_hat)
            sts_y_true.extend(b_labels)
            sts_sent_ids.extend(b_sent_ids)
        pearson_mat = np.corrcoef(sts_y_pred, sts_y_true)
        sts_corr = pearson_mat[1][0]

        print(f'Sentiment classification accuracy: {sentiment_accuracy:.3f}')
        print(f'Paraphrase detection accuracy: {paraphrase_accuracy:.3f}')
        print(f'Semantic Textual Similarity correlation: {sts_corr:.3f}')

        return (sentiment_accuracy, sst_y_pred, sst_sent_ids,
                paraphrase_accuracy, para_y_pred, para_sent_ids,
                sts_corr, sts_y_pred, sts_sent_ids)


# Evaluate multitask model on test sets.
def model_eval_test_multitask(sentiment_dataloader,
                              paraphrase_dataloader,
                              sts_dataloader,
                              model, device):
    model.eval()  # Switch to eval mode, which turns off randomness like dropout.

    with torch.no_grad():
        # Evaluate sentiment classification.
        sst_y_pred = []
        sst_sent_ids = []
        for step, batch in enumerate(tqdm(sentiment_dataloader, desc='eval', disable=TQDM_DISABLE)):
            b_ids, b_mask, b_sent_ids = batch['token_ids'], batch['attention_mask'], batch['sent_ids']

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)

            logits = model.predict_sentiment(b_ids, b_mask)
            y_hat = logits.argmax(dim=-1).flatten().cpu().numpy()

            sst_y_pred.extend(y_hat)
            sst_sent_ids.extend(b_sent_ids)

        # Evaluate paraphrase detection.
        para_y_pred = []
        para_sent_ids = []
        for step, batch in enumerate(tqdm(paraphrase_dataloader, desc='eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                            batch['token_ids_2'], batch['attention_mask_2'],
                            batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            logits = model.predict_paraphrase(b_ids1, b_mask1, b_ids2, b_mask2)
            y_hat = logits.sigmoid().round().flatten().cpu().numpy()

            para_y_pred.extend(y_hat)
            para_sent_ids.extend(b_sent_ids)

        # Evaluate semantic textual similarity.
        sts_y_pred = []
        sts_sent_ids = []
        for step, batch in enumerate(tqdm(sts_dataloader, desc='eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                            batch['token_ids_2'], batch['attention_mask_2'],
                            batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            logits = model.predict_similarity(b_ids1, b_mask1, b_ids2, b_mask2)
            y_hat = logits.flatten().cpu().numpy()

            sts_y_pred.extend(y_hat)
            sts_sent_ids.extend(b_sent_ids)

        return (sst_y_pred, sst_sent_ids,
                para_y_pred, para_sent_ids,
                sts_y_pred, sts_sent_ids)

justfile
DELETED
@@ -1,14 +0,0 @@
# Testing on a Google Cloud VM with no GPU and 8 CPU cores.

# If this doesn't meet your needs, add the --use_gpu
# or --num-cpu-cores arguments to the existing commands.


default:
    @just --list

last-linear:
    python classifier.py

full-model:
    python classifier.py --fine-tune-mode full-model --lr 1e-5

optimizer.py
DELETED
@@ -1,90 +0,0 @@
from typing import Callable, Iterable, Tuple
import math

import torch
from torch.optim import Optimizer


class AdamW(Optimizer):
    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def step(self, closure: Callable = None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                # Access state
                state = self.state[p]

                # Initialize state if not already done
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                # Hyperparameters
                alpha = group["lr"]
                beta1, beta2 = group["betas"]
                eps = group["eps"]
                weight_decay = group["weight_decay"]
                correct_bias = group["correct_bias"]

                # Retrieve state variables
                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]
                step = state["step"]

                # Update step
                step += 1
                state["step"] = step

                # Update biased first and second moment estimates
                exp_avg.mul_(beta1).add_(grad, alpha=(1 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

                # Compute bias-corrected moments
                if correct_bias:
                    bias_correction1 = 1 - beta1 ** step
                    bias_correction2 = 1 - beta2 ** step
                    exp_avg_corr = exp_avg / bias_correction1
                    exp_avg_sq_corr = exp_avg_sq / bias_correction2
                else:
                    exp_avg_corr = exp_avg
                    exp_avg_sq_corr = exp_avg_sq

                # Update parameters
                denom = exp_avg_sq_corr.sqrt().add_(eps)
                step_size = alpha
                p.data.addcdiv_(exp_avg_corr, denom, value=-step_size)

                # Apply weight decay (decoupled from the gradient update, as in AdamW)
                if weight_decay != 0:
                    p.data.add_(p.data, alpha=-alpha * weight_decay)

        return loss

optimizer_test.npy
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:77b817e0dce16a9bc8d3a6bcb88035db68f7d783dc8a565737581fadd05db815
size 152

optimizer_test.py
DELETED
@@ -1,34 +0,0 @@
import torch
import numpy as np
from optimizer import AdamW

seed = 0


def test_optimizer(opt_class) -> torch.Tensor:
    rng = np.random.default_rng(seed)
    torch.manual_seed(seed)
    model = torch.nn.Linear(3, 2, bias=False)
    opt = opt_class(
        model.parameters(),
        lr=1e-3,
        weight_decay=1e-4,
        correct_bias=True,
    )
    for i in range(1000):
        opt.zero_grad()
        x = torch.FloatTensor(rng.uniform(size=[model.in_features]))
        y_hat = model(x)
        y = torch.Tensor([x[0] + x[1], -x[2]])
        loss = ((y - y_hat) ** 2).sum()
        loss.backward()
        opt.step()
    return model.weight.detach()


ref = torch.tensor(np.load("optimizer_test.npy"))
actual = test_optimizer(AdamW)
print(ref)
print(actual)
assert torch.allclose(ref, actual, atol=1e-6, rtol=1e-4)
print("Optimizer test passed!")

predictions/README
DELETED
@@ -1,2 +0,0 @@
By default, `classifier.py` and `multitask_classifier.py` write your model predictions into this folder.
Before running prepare_submit.py, make sure that this directory has been populated!

prepare_submit.py
DELETED
@@ -1,18 +0,0 @@
# Creates a zip file for submission on Gradescope.

import os
import zipfile

required_files = [p for p in os.listdir('.') if p.endswith('.py')] + \
                 [f'predictions/{p}' for p in os.listdir('predictions')]

def main():
    aid = 'cs224n_default_final_project_submission'
    path = os.getcwd()
    with zipfile.ZipFile(f"{aid}.zip", 'w') as zz:
        for file in required_files:
            zz.write(file, os.path.join(".", file))
    print(f"Submission zip file created: {aid}.zip")

if __name__ == '__main__':
    main()

prompt
DELETED
@@ -1,3 +0,0 @@
I want to fine-tune minBERT with Unsupervised SimCSE to perform sentiment analysis, but I don't yet know how to go about it. As I understand it, I would fine-tune the minBERT model with SimCSE to obtain better embeddings, then pass those embeddings through the SentimentClassifier for classification. But what is the right approach?
I have thought of the following two options (there may be others I haven't come up with). Please take a look!
1. Fine-tune minBERT with SimCSE first, then fine-tune the SentimentClassifier: use the STS-B dataset or the Twitter Sentiment Dataset to fine-tune minBERT, then evaluate the

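The question above maps to a two-stage recipe. As a rough illustration of option 1, here is a minimal, self-contained sketch; the modules `encoder` and `head` and the tensor `x` are toy stand-ins, not the repo's minBERT or BertSentimentClassifier:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
encoder = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.Dropout(0.1))  # stand-in for minBERT
head = torch.nn.Linear(32, 5)   # stand-in classifier head (e.g. 5 SST classes)
x = torch.randn(8, 16)          # stand-in for a batch of encoded sentences

# Stage 1: unsupervised SimCSE. Two forward passes over the same batch see
# different dropout masks; positives sit on the diagonal of the similarity matrix.
opt_cse = torch.optim.AdamW(encoder.parameters(), lr=1e-5)
z1, z2 = encoder(x), encoder(x)
sim = F.cosine_similarity(z1.unsqueeze(1), z2.unsqueeze(0), dim=-1) / 0.05
loss_cse = F.cross_entropy(sim, torch.arange(len(x)))
opt_cse.zero_grad()
loss_cse.backward()
opt_cse.step()

# Stage 2: supervised fine-tuning of encoder + head with cross-entropy.
opt_clf = torch.optim.AdamW(list(encoder.parameters()) + list(head.parameters()), lr=1e-5)
labels = torch.randint(0, 5, (8,))
loss_clf = F.cross_entropy(head(encoder(x)), labels)
opt_clf.zero_grad()
loss_clf.backward()
opt_clf.step()

unsup_simcse.py below follows this same two-optimizer structure on the real model.
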
sanity_check.data
DELETED
Binary file (56.4 kB)

sanity_check.py
DELETED
@@ -1,19 +0,0 @@
import torch
from bert import BertModel


sanity_data = torch.load("./sanity_check.data", weights_only=True)
sent_ids = torch.tensor([[101, 7592, 2088, 102, 0, 0, 0, 0],
                         [101, 7592, 15756, 2897, 2005, 17953, 2361, 102]])
att_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1]])

# Load model.
bert = BertModel.from_pretrained('bert-base-uncased')
outputs = bert(sent_ids, att_mask)
att_mask = att_mask.unsqueeze(-1)
outputs['last_hidden_state'] = outputs['last_hidden_state'] * att_mask
sanity_data['last_hidden_state'] = sanity_data['last_hidden_state'] * att_mask

for k in ['last_hidden_state', 'pooler_output']:
    assert torch.allclose(outputs[k], sanity_data[k], atol=1e-5, rtol=1e-3)
print("Your BERT implementation is correct!")

setup.sh
DELETED
@@ -1,13 +0,0 @@
#!/usr/bin/env bash

conda create -n cs224n_dfp python=3.8
conda activate cs224n_dfp

pip install torch torchvision torchaudio
pip install tqdm==4.58.0
pip install requests==2.25.1
pip install importlib-metadata==3.7.0
pip install filelock==3.0.12
pip install sklearn==0.0
pip install tokenizers==0.15
pip install explainaboard_client==0.0.7

tokenizer.py
DELETED
The diff for this file is too large to render. See raw diff.

unsup_simcse.py
DELETED
@@ -1,252 +0,0 @@
import csv
import torch
import random
import argparse
import numpy as np
import pandas as pd
import torch.nn.functional as F

from tqdm import tqdm
from torch import Tensor
from types import SimpleNamespace
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score, accuracy_score

from bert import BertModel
from optimizer import AdamW
from classifier import seed_everything, tokenizer
from classifier import SentimentDataset, BertSentimentClassifier


TQDM_DISABLE = False


class AmazonDataset(Dataset):
    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents = [x[0] for x in data]
        sent_ids = [x[1] for x in data]
        encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, sent_ids

    def collate_fn(self, data):
        token_ids, attention_mask, sent_ids = self.pad_data(data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sent_ids': sent_ids
        }

        return batched_data


def load_data(filename, flag='train'):
    '''
    - for amazon dataset: list of (sent, sent_id)
    - for test dataset: list of (sent, sent_id)
    - for train dataset: list of (sent, label, sent_id)
    '''

    if flag == 'amazon':
        df = pd.read_parquet(filename)
        data = list(zip(df['content'], df.index))
    else:
        data, num_labels = [], set()

        with open(filename, 'r') as fp:
            if flag == 'test':
                for record in csv.DictReader(fp, delimiter='\t'):
                    sent = record['sentence'].lower().strip()
                    sent_id = record['id'].lower().strip()
                    data.append((sent, sent_id))
            else:
                for record in csv.DictReader(fp, delimiter='\t'):
                    sent = record['sentence'].lower().strip()
                    sent_id = record['id'].lower().strip()
                    label = int(record['sentiment'].strip())
                    num_labels.add(label)
                    data.append((sent, label, sent_id))

    print(f"loaded {len(data)} examples from {filename}")
    if flag in ['test', 'amazon']:
        return data
    else:
        return data, len(num_labels)


def save_model(model, optimizer, args, config, filepath):
    save_info = {
        'model': model.state_dict(),
        'optim': optimizer.state_dict(),
        'args': args,
        'model_config': config,
        'system_rng': random.getstate(),
        'numpy_rng': np.random.get_state(),
        'torch_rng': torch.random.get_rng_state(),
    }

    torch.save(save_info, filepath)
    print(f"saved the model to {filepath}")


def contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, temp=0.05):
    '''
    embeds_1: [batch_size, hidden_size]
    embeds_2: [batch_size, hidden_size]
    '''

    # [batch_size, batch_size]
    sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_2.unsqueeze(0), dim=-1) / temp

    # [batch_size]
    positive_sim = torch.diagonal(sim_matrix)

    # [batch_size]
    nume = torch.exp(positive_sim)

    # [batch_size]
    deno = torch.exp(sim_matrix).sum(1)

    # [batch_size]
    loss_per_batch = -torch.log(nume / deno)

    return loss_per_batch.mean()


def train(args):
    '''
    Training Pipeline
    -----------------
    1. Load the Amazon Polarity and SST datasets.
    2. Determine batch_size (64) and number of batches (?).
    3. Initialize SentimentClassifier (including bert).
    4. Loop through 10 epochs.
    5. Fine-tune minBERT with the SimCSE loss function.
    6. Fine-tune the classifier with the cross-entropy function.
    7. Backpropagate using the Adam optimizer for both.
    8. Evaluate the model on the dev dataset.
    9. If dev_acc > best_dev_acc: save_model(...)
    '''

    amazon_data = load_data(args.train_bert, 'amazon')
    train_data, num_labels = load_data(args.train, 'train')
    dev_data, _ = load_data(args.dev, 'valid')  # labeled splits return (data, num_labels)

    amazon_dataset = AmazonDataset(amazon_data, args)
    train_dataset = SentimentDataset(train_data, args)
    dev_dataset = SentimentDataset(dev_data, args)

    amazon_dataloader = DataLoader(amazon_dataset, shuffle=True, batch_size=args.batch_size_cse,
                                   num_workers=args.num_cpu_cores, collate_fn=amazon_dataset.collate_fn)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_classifier,
                                  num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
    dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size_classifier,
                                num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)

    config = SimpleNamespace(
        hidden_dropout_prob=args.hidden_dropout_prob,
        num_labels=num_labels,
        hidden_size=768,
        data_dir='.',
        fine_tune_mode='full-model'
    )

    device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
    model = BertSentimentClassifier(config)
    model = model.to(device)

    optimizer_cse = AdamW(model.bert.parameters(), lr=args.lr_cse)
    optimizer_classifier = AdamW(model.parameters(), lr=args.lr_classifier)
    best_dev_acc = 0

    # ---- Training minBERT using SimCSE ---- #
    for epoch in range(args.epochs):
        model.bert.train()
        train_loss = num_batches = 0
        for batch in tqdm(amazon_dataloader, f'train-amazon-{epoch}', leave=False, disable=TQDM_DISABLE):
            b_ids, b_mask = batch['token_ids'], batch['attention_mask']
            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)

            # Get different embeddings with different dropout masks
            logits_1 = model.bert(b_ids, b_mask)['pooler_output']
            logits_2 = model.bert(b_ids, b_mask)['pooler_output']

            # Calculate mean SimCSE loss
            loss = contrastive_loss(logits_1, logits_2)

            # Backpropagation
            optimizer_cse.zero_grad()
            loss.backward()
            optimizer_cse.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss = train_loss / num_batches
        print(f"Epoch {epoch}: train loss :: {train_loss:.3f}")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=11711)
    parser.add_argument("--num-cpu-cores", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--use_gpu", action='store_true')
    parser.add_argument("--batch_size_cse", type=int, default=8)
    parser.add_argument("--batch_size_sst", type=int, default=64)
    parser.add_argument("--batch_size_cfimdb", type=int, default=8)
    parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
    parser.add_argument("--lr_cse", type=float, default=1e-5)
    parser.add_argument("--lr_classifier", type=float, default=1e-5)

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()
    seed_everything(args.seed)
    torch.set_num_threads(args.num_cpu_cores)

    print('Finetuning minBERT with Unsupervised SimCSE...')
    config = SimpleNamespace(
        filepath='contrastive-nli.pt',
        lr_cse=args.lr_cse,
        lr_classifier=args.lr_classifier,
        num_cpu_cores=args.num_cpu_cores,
        use_gpu=args.use_gpu,
        epochs=args.epochs,
        batch_size_cse=args.batch_size_cse,
        batch_size_classifier=args.batch_size_sst,
        hidden_dropout_prob=args.hidden_dropout_prob,
        train_bert='data/amazon-polarity.parquet',
        train='data/ids-sst-train.csv',
        dev='data/ids-sst-dev.csv',
        test='data/ids-sst-test-student.csv'
    )

    train(config)

    # model = BertModel.from_pretrained('bert-base-uncased')

    # model.eval()

    # s = set()
    # for param in model.parameters():
    #     s.add(param.requires_grad)

    # print(s)

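A quick sanity check on the contrastive_loss above: since the positives are the diagonal entries of the scaled similarity matrix, the loss is equivalent to row-wise cross-entropy with targets arange(batch_size). A small standalone check (only torch assumed; the loss is inlined rather than imported from the deleted file):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
e1, e2 = torch.randn(4, 8), torch.randn(4, 8)
sim = F.cosine_similarity(e1.unsqueeze(1), e2.unsqueeze(0), dim=-1) / 0.05

# contrastive_loss, written out as in the file above
loss_manual = (-torch.log(torch.exp(sim.diagonal()) / torch.exp(sim).sum(1))).mean()
# the equivalent cross-entropy formulation
loss_ce = F.cross_entropy(sim, torch.arange(4))

assert torch.allclose(loss_manual, loss_ce, atol=1e-6)
print("contrastive_loss matches cross-entropy over the similarity matrix")
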
utils.py
DELETED
@@ -1,347 +0,0 @@
from typing import Dict, List, Optional, Union, Tuple, BinaryIO
import os
import sys
import json
import shutil
import tarfile
import tempfile
import copy
from contextlib import contextmanager
from zipfile import ZipFile, is_zipfile
from tqdm.auto import tqdm
from functools import partial
from urllib.parse import urlparse
from pathlib import Path
import requests
from hashlib import sha256
from filelock import FileLock
import importlib_metadata
import torch
import torch.nn as nn
from torch import Tensor
import fnmatch

__version__ = "4.0.0"
_torch_version = importlib_metadata.version("torch")

hf_cache_home = os.path.expanduser(os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")))
default_cache_path = os.path.join(hf_cache_home, "transformers")
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)

PRESET_MIRROR_DICT = {
    "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
    "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
WEIGHTS_NAME = "pytorch_model.bin"
CONFIG_NAME = "config.json"


def is_torch_available():
    return True


def is_tf_available():
    return False


def is_remote_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
    headers = copy.deepcopy(headers)
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)
    r = requests.get(url, stream=True, proxies=proxies, headers=headers)
    r.raise_for_status()
    content_length = r.headers.get("Content-Length")
    total = resume_size + int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        initial=resume_size,
        desc="Downloading",
        disable=False,
    )
    for chunk in r.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    url_bytes = url.encode("utf-8")
    filename = sha256(url_bytes).hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        filename += "." + sha256(etag_bytes).hexdigest()

    if url.endswith(".h5"):
        filename += ".h5"

    return filename


def hf_bucket_url(
    model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
    if subfolder is not None:
        filename = f"{subfolder}/{filename}"

    if mirror:
        endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
        legacy_format = "/" not in model_id
        if legacy_format:
            return f"{endpoint}/{model_id}-{filename}"
        else:
            return f"{endpoint}/{model_id}/{filename}"

    if revision is None:
        revision = "main"
    return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)


def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
    if is_torch_available():
        ua += f"; torch/{_torch_version}"
    if is_tf_available():
        ua += f"; tensorflow/{_tf_version}"
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    return ua


def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    headers = {"user-agent": http_user_agent(user_agent)}
    if isinstance(use_auth_token, str):
        headers["authorization"] = "Bearer {}".format(use_auth_token)
    elif use_auth_token:
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = "Bearer {}".format(token)

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
            r.raise_for_status()
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            # We favor a custom header indicating the etag of the linked resource,
            # and we fall back to the regular etag header.
            # If we don't have any of those, raise an error.
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # In case of a redirect, save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # etag is already None
            pass

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection or we passed local_files_only.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might've been found if local_files_only=False.
                # Notify the user about that.
                if local_files_only:
                    raise FileNotFoundError(
                        "Cannot find the requested files in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False."
                    )
                else:
                    raise ValueError(
                        "Connection error, and we cannot find the requested files in the cached path."
                        " Please try again or make sure your Internet connection is on."
                    )

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager() -> "io.BufferedWriter":
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
            resume_size = 0

        # Download to a temporary file, then copy to the cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)

        os.replace(temp_file.name, cache_path)

        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path


def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if extract_compressed_file:
        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
            return output_path

        # Path where we extract compressed archives
        # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
        output_dir, output_file = os.path.split(output_path)
        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
            return output_path_extracted

        # Prevent parallel extractions
        lock_path = output_path + ".lock"
        with FileLock(lock_path):
            shutil.rmtree(output_path_extracted, ignore_errors=True)
            os.makedirs(output_path_extracted)
            if is_zipfile(output_path):
                with ZipFile(output_path, "r") as zip_file:
                    zip_file.extractall(output_path_extracted)
                    zip_file.close()
            elif tarfile.is_tarfile(output_path):
                tar_file = tarfile.open(output_path)
                tar_file.extractall(output_path_extracted)
                tar_file.close()
            else:
                raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

        return output_path_extracted

    return output_path


def get_parameter_dtype(parameter: nn.Module):
    try:
        return next(parameter.parameters()).dtype
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype


def get_extended_attention_mask(attention_mask: Tensor, dtype) -> Tensor:
    # attention_mask [batch_size, seq_length]
    assert attention_mask.dim() == 2
    # [batch_size, 1, 1, seq_length] for multi-head attention
    extended_attention_mask = attention_mask[:, None, None, :]
    extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    return extended_attention_mask

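get_extended_attention_mask at the end of utils.py turns a padding mask into the additive bias that BERT's attention scores expect. A tiny standalone demo of the same transform (assumes only torch):

import torch

# 1s (attended positions) become 0.0; 0s (padding) become -10000.0,
# broadcastable over [batch, heads, query_len, key_len] attention scores.
mask = torch.tensor([[1, 1, 1, 0, 0]])  # one sentence with two padding tokens
ext = (1.0 - mask[:, None, None, :].float()) * -10000.0
print(ext.shape)  # torch.Size([1, 1, 1, 5])
print(ext)        # -0. at attended positions, -10000. at padding positions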