Commit
·
7587354
1
Parent(s):
51eaae4
contrastive commit 1
Browse files- base_bert.py +240 -240
- classifier.py +67 -61
- data/{sts-test-student.csv → nli-dev.parquet} +2 -2
- data/{sts-train.csv → nli-test.parquet} +2 -2
- data/{quora-test-student.csv → nli-train.parquet} +2 -2
- data/{sts-dev.csv → stsb-dev.parquet} +2 -2
- data/{quora-dev.csv → stsb-test.parquet} +2 -2
- data/stsb-train.parquet +3 -0
- data/{quora-train.csv → twitter-unsup.csv} +2 -2
- datasets.py +0 -272
- justfile +1 -0
- multitask_classifier.py +0 -340
- prompt +3 -0
- trainings/last-layer-w-dropout.txt +4 -4
- unsup_simcse.py +204 -0
base_bert.py
CHANGED
@@ -5,244 +5,244 @@ from utils import *
|
|
5 |
|
6 |
|
7 |
class BertPreTrainedModel(nn.Module):
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
else:
|
69 |
-
model_kwargs = kwargs
|
70 |
-
|
71 |
-
# Load model
|
72 |
-
if pretrained_model_name_or_path is not None:
|
73 |
-
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
74 |
-
if os.path.isdir(pretrained_model_name_or_path):
|
75 |
-
# Load from a PyTorch checkpoint
|
76 |
-
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
77 |
-
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
78 |
-
archive_file = pretrained_model_name_or_path
|
79 |
-
else:
|
80 |
-
archive_file = hf_bucket_url(
|
81 |
-
pretrained_model_name_or_path,
|
82 |
-
filename=WEIGHTS_NAME,
|
83 |
-
revision=revision,
|
84 |
-
mirror=mirror,
|
85 |
-
)
|
86 |
-
try:
|
87 |
-
# Load from URL or cache if already cached
|
88 |
-
resolved_archive_file = cached_path(
|
89 |
-
archive_file,
|
90 |
-
cache_dir=cache_dir,
|
91 |
-
force_download=force_download,
|
92 |
-
proxies=proxies,
|
93 |
-
resume_download=resume_download,
|
94 |
-
local_files_only=local_files_only,
|
95 |
-
use_auth_token=use_auth_token,
|
96 |
-
)
|
97 |
-
except EnvironmentError as err:
|
98 |
-
#logger.error(err)
|
99 |
-
msg = (
|
100 |
-
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
101 |
-
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
|
102 |
-
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}.\n\n"
|
103 |
-
)
|
104 |
-
raise EnvironmentError(msg)
|
105 |
-
else:
|
106 |
-
resolved_archive_file = None
|
107 |
-
|
108 |
-
config.name_or_path = pretrained_model_name_or_path
|
109 |
-
|
110 |
-
# Instantiate model.
|
111 |
-
model = cls(config, *model_args, **model_kwargs)
|
112 |
-
|
113 |
-
if state_dict is None:
|
114 |
-
try:
|
115 |
-
state_dict = torch.load(resolved_archive_file, map_location="cpu", weights_only=True)
|
116 |
-
except Exception:
|
117 |
-
raise OSError(
|
118 |
-
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
|
119 |
-
f"at '{resolved_archive_file}'"
|
120 |
-
)
|
121 |
-
|
122 |
-
missing_keys = []
|
123 |
-
unexpected_keys = []
|
124 |
-
error_msgs = []
|
125 |
-
|
126 |
-
# Convert old format to new format if needed from a PyTorch state_dict
|
127 |
-
old_keys = []
|
128 |
-
new_keys = []
|
129 |
-
m = {'embeddings.word_embeddings': 'word_embedding',
|
130 |
-
'embeddings.position_embeddings': 'pos_embedding',
|
131 |
-
'embeddings.token_type_embeddings': 'tk_type_embedding',
|
132 |
-
'embeddings.LayerNorm': 'embed_layer_norm',
|
133 |
-
'embeddings.dropout': 'embed_dropout',
|
134 |
-
'encoder.layer': 'bert_layers',
|
135 |
-
'pooler.dense': 'pooler_dense',
|
136 |
-
'pooler.activation': 'pooler_af',
|
137 |
-
'attention.self': "self_attention",
|
138 |
-
'attention.output.dense': 'attention_dense',
|
139 |
-
'attention.output.LayerNorm': 'attention_layer_norm',
|
140 |
-
'attention.output.dropout': 'attention_dropout',
|
141 |
-
'intermediate.dense': 'interm_dense',
|
142 |
-
'intermediate.intermediate_act_fn': 'interm_af',
|
143 |
-
'output.dense': 'out_dense',
|
144 |
-
'output.LayerNorm': 'out_layer_norm',
|
145 |
-
'output.dropout': 'out_dropout'}
|
146 |
-
|
147 |
-
for key in state_dict.keys():
|
148 |
-
new_key = None
|
149 |
-
if "gamma" in key:
|
150 |
-
new_key = key.replace("gamma", "weight")
|
151 |
-
if "beta" in key:
|
152 |
-
new_key = key.replace("beta", "bias")
|
153 |
-
for x, y in m.items():
|
154 |
-
if new_key is not None:
|
155 |
-
_key = new_key
|
156 |
else:
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
class BertPreTrainedModel(nn.Module):
|
8 |
+
config_class = BertConfig
|
9 |
+
base_model_prefix = "bert"
|
10 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
11 |
+
_keys_to_ignore_on_load_unexpected = None
|
12 |
+
|
13 |
+
def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
|
14 |
+
super().__init__()
|
15 |
+
self.config = config
|
16 |
+
self.name_or_path = config.name_or_path
|
17 |
+
|
18 |
+
def init_weights(self):
|
19 |
+
# Initialize weights
|
20 |
+
self.apply(self._init_weights)
|
21 |
+
|
22 |
+
def _init_weights(self, module):
|
23 |
+
""" Initialize the weights """
|
24 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
25 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
26 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
27 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
28 |
+
elif isinstance(module, nn.LayerNorm):
|
29 |
+
module.bias.data.zero_()
|
30 |
+
module.weight.data.fill_(1.0)
|
31 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
32 |
+
module.bias.data.zero_()
|
33 |
+
|
34 |
+
@property
|
35 |
+
def dtype(self) -> dtype:
|
36 |
+
return get_parameter_dtype(self)
|
37 |
+
|
38 |
+
@classmethod
|
39 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
40 |
+
config = kwargs.pop("config", None)
|
41 |
+
state_dict = kwargs.pop("state_dict", None)
|
42 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
43 |
+
force_download = kwargs.pop("force_download", False)
|
44 |
+
resume_download = kwargs.pop("resume_download", False)
|
45 |
+
proxies = kwargs.pop("proxies", None)
|
46 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
47 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
48 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
49 |
+
revision = kwargs.pop("revision", None)
|
50 |
+
mirror = kwargs.pop("mirror", None)
|
51 |
+
|
52 |
+
# Load config if we don't provide a configuration
|
53 |
+
if not isinstance(config, PretrainedConfig):
|
54 |
+
config_path = config if config is not None else pretrained_model_name_or_path
|
55 |
+
config, model_kwargs = cls.config_class.from_pretrained(
|
56 |
+
config_path,
|
57 |
+
*model_args,
|
58 |
+
cache_dir=cache_dir,
|
59 |
+
return_unused_kwargs=True,
|
60 |
+
force_download=force_download,
|
61 |
+
resume_download=resume_download,
|
62 |
+
proxies=proxies,
|
63 |
+
local_files_only=local_files_only,
|
64 |
+
use_auth_token=use_auth_token,
|
65 |
+
revision=revision,
|
66 |
+
**kwargs,
|
67 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
else:
|
69 |
+
model_kwargs = kwargs
|
70 |
+
|
71 |
+
# Load model
|
72 |
+
if pretrained_model_name_or_path is not None:
|
73 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
74 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
75 |
+
# Load from a PyTorch checkpoint
|
76 |
+
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
77 |
+
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
78 |
+
archive_file = pretrained_model_name_or_path
|
79 |
+
else:
|
80 |
+
archive_file = hf_bucket_url(
|
81 |
+
pretrained_model_name_or_path,
|
82 |
+
filename=WEIGHTS_NAME,
|
83 |
+
revision=revision,
|
84 |
+
mirror=mirror,
|
85 |
+
)
|
86 |
+
try:
|
87 |
+
# Load from URL or cache if already cached
|
88 |
+
resolved_archive_file = cached_path(
|
89 |
+
archive_file,
|
90 |
+
cache_dir=cache_dir,
|
91 |
+
force_download=force_download,
|
92 |
+
proxies=proxies,
|
93 |
+
resume_download=resume_download,
|
94 |
+
local_files_only=local_files_only,
|
95 |
+
use_auth_token=use_auth_token,
|
96 |
+
)
|
97 |
+
except EnvironmentError as err:
|
98 |
+
#logger.error(err)
|
99 |
+
msg = (
|
100 |
+
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
101 |
+
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
|
102 |
+
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}.\n\n"
|
103 |
+
)
|
104 |
+
raise EnvironmentError(msg)
|
105 |
+
else:
|
106 |
+
resolved_archive_file = None
|
107 |
+
|
108 |
+
config.name_or_path = pretrained_model_name_or_path
|
109 |
+
|
110 |
+
# Instantiate model.
|
111 |
+
model = cls(config, *model_args, **model_kwargs)
|
112 |
+
|
113 |
+
if state_dict is None:
|
114 |
+
try:
|
115 |
+
state_dict = torch.load(resolved_archive_file, map_location="cpu", weights_only=True)
|
116 |
+
except Exception:
|
117 |
+
raise OSError(
|
118 |
+
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
|
119 |
+
f"at '{resolved_archive_file}'"
|
120 |
+
)
|
121 |
+
|
122 |
+
missing_keys = []
|
123 |
+
unexpected_keys = []
|
124 |
+
error_msgs = []
|
125 |
+
|
126 |
+
# Convert old format to new format if needed from a PyTorch state_dict
|
127 |
+
old_keys = []
|
128 |
+
new_keys = []
|
129 |
+
m = {'embeddings.word_embeddings': 'word_embedding',
|
130 |
+
'embeddings.position_embeddings': 'pos_embedding',
|
131 |
+
'embeddings.token_type_embeddings': 'tk_type_embedding',
|
132 |
+
'embeddings.LayerNorm': 'embed_layer_norm',
|
133 |
+
'embeddings.dropout': 'embed_dropout',
|
134 |
+
'encoder.layer': 'bert_layers',
|
135 |
+
'pooler.dense': 'pooler_dense',
|
136 |
+
'pooler.activation': 'pooler_af',
|
137 |
+
'attention.self': "self_attention",
|
138 |
+
'attention.output.dense': 'attention_dense',
|
139 |
+
'attention.output.LayerNorm': 'attention_layer_norm',
|
140 |
+
'attention.output.dropout': 'attention_dropout',
|
141 |
+
'intermediate.dense': 'interm_dense',
|
142 |
+
'intermediate.intermediate_act_fn': 'interm_af',
|
143 |
+
'output.dense': 'out_dense',
|
144 |
+
'output.LayerNorm': 'out_layer_norm',
|
145 |
+
'output.dropout': 'out_dropout'}
|
146 |
+
|
147 |
+
for key in state_dict.keys():
|
148 |
+
new_key = None
|
149 |
+
if "gamma" in key:
|
150 |
+
new_key = key.replace("gamma", "weight")
|
151 |
+
if "beta" in key:
|
152 |
+
new_key = key.replace("beta", "bias")
|
153 |
+
for x, y in m.items():
|
154 |
+
if new_key is not None:
|
155 |
+
_key = new_key
|
156 |
+
else:
|
157 |
+
_key = key
|
158 |
+
if x in key:
|
159 |
+
new_key = _key.replace(x, y)
|
160 |
+
if new_key:
|
161 |
+
old_keys.append(key)
|
162 |
+
new_keys.append(new_key)
|
163 |
+
|
164 |
+
for old_key, new_key in zip(old_keys, new_keys):
|
165 |
+
# print(old_key, new_key)
|
166 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
167 |
+
|
168 |
+
# copy state_dict so _load_from_state_dict can modify it
|
169 |
+
metadata = getattr(state_dict, "_metadata", None)
|
170 |
+
state_dict = state_dict.copy()
|
171 |
+
if metadata is not None:
|
172 |
+
state_dict._metadata = metadata
|
173 |
+
|
174 |
+
your_bert_params = [f"bert.{x[0]}" for x in model.named_parameters()]
|
175 |
+
for k in state_dict:
|
176 |
+
if k not in your_bert_params and not k.startswith("cls."):
|
177 |
+
possible_rename = [x for x in k.split(".")[1:-1] if x in m.values()]
|
178 |
+
raise ValueError(f"{k} cannot be reload to your model, one/some of {possible_rename} we provided have been renamed")
|
179 |
+
|
180 |
+
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
181 |
+
# so we need to apply the function recursively.
|
182 |
+
def load(module: nn.Module, prefix=""):
|
183 |
+
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
184 |
+
module._load_from_state_dict(
|
185 |
+
state_dict,
|
186 |
+
prefix,
|
187 |
+
local_metadata,
|
188 |
+
True,
|
189 |
+
missing_keys,
|
190 |
+
unexpected_keys,
|
191 |
+
error_msgs,
|
192 |
+
)
|
193 |
+
for name, child in module._modules.items():
|
194 |
+
if child is not None:
|
195 |
+
load(child, prefix + name + ".")
|
196 |
+
|
197 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
198 |
+
start_prefix = ""
|
199 |
+
model_to_load = model
|
200 |
+
has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
|
201 |
+
if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
|
202 |
+
start_prefix = cls.base_model_prefix + "."
|
203 |
+
if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
204 |
+
model_to_load = getattr(model, cls.base_model_prefix)
|
205 |
+
load(model_to_load, prefix=start_prefix)
|
206 |
+
|
207 |
+
if model.__class__.__name__ != model_to_load.__class__.__name__:
|
208 |
+
base_model_state_dict = model_to_load.state_dict().keys()
|
209 |
+
head_model_state_dict_without_base_prefix = [
|
210 |
+
key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
|
211 |
+
]
|
212 |
+
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
|
213 |
+
|
214 |
+
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
215 |
+
# the user.
|
216 |
+
if cls._keys_to_ignore_on_load_missing is not None:
|
217 |
+
for pat in cls._keys_to_ignore_on_load_missing:
|
218 |
+
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
219 |
+
|
220 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
221 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
222 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
223 |
+
|
224 |
+
if len(error_msgs) > 0:
|
225 |
+
raise RuntimeError(
|
226 |
+
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
227 |
+
model.__class__.__name__, "\n\t".join(error_msgs)
|
228 |
+
)
|
229 |
+
)
|
230 |
+
|
231 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
232 |
+
model.eval()
|
233 |
+
|
234 |
+
if output_loading_info:
|
235 |
+
loading_info = {
|
236 |
+
"missing_keys": missing_keys,
|
237 |
+
"unexpected_keys": unexpected_keys,
|
238 |
+
"error_msgs": error_msgs,
|
239 |
+
}
|
240 |
+
return model, loading_info
|
241 |
+
|
242 |
+
if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
|
243 |
+
import torch_xla.core.xla_model as xm
|
244 |
+
|
245 |
+
model = xm.send_cpu_data_to_device(model, xm.xla_device())
|
246 |
+
model.to(xm.xla_device())
|
247 |
+
|
248 |
+
return model
|
classifier.py
CHANGED
@@ -3,6 +3,7 @@ from types import SimpleNamespace
|
|
3 |
import csv
|
4 |
|
5 |
import torch
|
|
|
6 |
import torch.nn.functional as F
|
7 |
from torch.utils.data import Dataset, DataLoader
|
8 |
from sklearn.metrics import f1_score, accuracy_score
|
@@ -10,7 +11,6 @@ from sklearn.metrics import f1_score, accuracy_score
|
|
10 |
from tokenizer import BertTokenizer
|
11 |
from bert import BertModel
|
12 |
from optimizer import AdamW
|
13 |
-
from tqdm import tqdm
|
14 |
|
15 |
|
16 |
TQDM_DISABLE=False
|
@@ -34,10 +34,10 @@ class BertSentimentClassifier(torch.nn.Module):
|
|
34 |
In the SST dataset, there are 5 sentiment categories (from 0 - "negative" to 4 - "positive").
|
35 |
Thus, your forward() should return one logit for each of the 5 classes.
|
36 |
'''
|
37 |
-
def __init__(self, config):
|
38 |
super(BertSentimentClassifier, self).__init__()
|
39 |
self.num_labels = config.num_labels
|
40 |
-
self.bert: BertModel = BertModel.from_pretrained('bert-base-uncased')
|
41 |
|
42 |
# Pretrain mode does not require updating BERT paramters.
|
43 |
assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
|
@@ -59,26 +59,21 @@ class BertSentimentClassifier(torch.nn.Module):
|
|
59 |
# the training loop currently uses F.cross_entropy as the loss function.
|
60 |
|
61 |
# Get the embedding for each input token.
|
62 |
-
|
63 |
-
|
64 |
-
# Feed to a transformer (BERT layers).
|
65 |
-
sequence_output = self.bert.encode(embedding_output, attention_mask=attention_mask)
|
66 |
-
|
67 |
-
# The final BERT contextualized embedding is the hidden state of [CLS] token (the first token).
|
68 |
-
cls_token_output = sequence_output[:, 0, :] # The first token is [CLS]
|
69 |
|
70 |
# Pass the [CLS] token representation through the classifier.
|
71 |
-
logits = self.classifier(self.dropout(
|
72 |
|
73 |
return logits
|
74 |
|
75 |
|
|
|
76 |
|
77 |
class SentimentDataset(Dataset):
|
78 |
def __init__(self, dataset, args):
|
79 |
self.dataset = dataset
|
80 |
self.p = args
|
81 |
-
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
82 |
|
83 |
def __len__(self):
|
84 |
return len(self.dataset)
|
@@ -91,7 +86,7 @@ class SentimentDataset(Dataset):
|
|
91 |
labels = [x[1] for x in data]
|
92 |
sent_ids = [x[2] for x in data]
|
93 |
|
94 |
-
encoding =
|
95 |
token_ids = torch.LongTensor(encoding['input_ids'])
|
96 |
attention_mask = torch.LongTensor(encoding['attention_mask'])
|
97 |
labels = torch.LongTensor(labels)
|
@@ -99,15 +94,15 @@ class SentimentDataset(Dataset):
|
|
99 |
return token_ids, attention_mask, labels, sents, sent_ids
|
100 |
|
101 |
def collate_fn(self, all_data):
|
102 |
-
token_ids, attention_mask, labels, sents, sent_ids= self.pad_data(all_data)
|
103 |
|
104 |
batched_data = {
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
|
112 |
return batched_data
|
113 |
|
@@ -116,7 +111,6 @@ class SentimentTestDataset(Dataset):
|
|
116 |
def __init__(self, dataset, args):
|
117 |
self.dataset = dataset
|
118 |
self.p = args
|
119 |
-
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
120 |
|
121 |
def __len__(self):
|
122 |
return len(self.dataset)
|
@@ -128,7 +122,7 @@ class SentimentTestDataset(Dataset):
|
|
128 |
sents = [x[0] for x in data]
|
129 |
sent_ids = [x[1] for x in data]
|
130 |
|
131 |
-
encoding =
|
132 |
token_ids = torch.LongTensor(encoding['input_ids'])
|
133 |
attention_mask = torch.LongTensor(encoding['attention_mask'])
|
134 |
|
@@ -138,34 +132,31 @@ class SentimentTestDataset(Dataset):
|
|
138 |
token_ids, attention_mask, sents, sent_ids= self.pad_data(all_data)
|
139 |
|
140 |
batched_data = {
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
|
147 |
return batched_data
|
148 |
|
149 |
|
150 |
# Load the data: a list of (sentence, label).
|
151 |
def load_data(filename, flag='train'):
|
152 |
-
num_labels =
|
153 |
data = []
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
sent = record['sentence'].lower().strip()
|
158 |
sent_id = record['id'].lower().strip()
|
159 |
data.append((sent,sent_id))
|
160 |
-
|
161 |
-
with open(filename, 'r') as fp:
|
162 |
-
for record in csv.DictReader(fp,delimiter = '\t'):
|
163 |
sent = record['sentence'].lower().strip()
|
164 |
sent_id = record['id'].lower().strip()
|
165 |
label = int(record['sentiment'].strip())
|
166 |
-
|
167 |
-
|
168 |
-
data.append((sent, label,sent_id))
|
169 |
print(f"load {len(data)} data from {filename}")
|
170 |
|
171 |
if flag == 'train':
|
@@ -253,9 +244,9 @@ def train(args):
|
|
253 |
dev_dataset = SentimentDataset(dev_data, args)
|
254 |
|
255 |
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
|
256 |
-
collate_fn=train_dataset.collate_fn)
|
257 |
dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
|
258 |
-
collate_fn=dev_dataset.collate_fn)
|
259 |
|
260 |
# Init model.
|
261 |
config = {'hidden_dropout_prob': args.hidden_dropout_prob,
|
@@ -311,7 +302,7 @@ def train(args):
|
|
311 |
def test(args):
|
312 |
with torch.no_grad():
|
313 |
device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
|
314 |
-
saved = torch.load(args.filepath)
|
315 |
config = saved['model_config']
|
316 |
model = BertSentimentClassifier(config)
|
317 |
model.load_state_dict(saved['model'])
|
@@ -320,38 +311,44 @@ def test(args):
|
|
320 |
|
321 |
dev_data = load_data(args.dev, 'valid')
|
322 |
dev_dataset = SentimentDataset(dev_data, args)
|
323 |
-
dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
|
|
|
324 |
|
325 |
-
test_data = load_data(args.test, 'test')
|
326 |
-
test_dataset = SentimentTestDataset(test_data, args)
|
327 |
-
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=test_dataset.collate_fn)
|
328 |
-
|
329 |
dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
|
330 |
print('DONE DEV')
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
|
344 |
|
345 |
def get_args():
|
346 |
parser = argparse.ArgumentParser()
|
347 |
parser.add_argument("--seed", type=int, default=11711)
|
|
|
348 |
parser.add_argument("--epochs", type=int, default=10)
|
349 |
parser.add_argument("--fine-tune-mode", type=str,
|
350 |
help='last-linear-layer: the BERT parameters are frozen and the task specific head parameters are updated; full-model: BERT parameters are updated as well',
|
351 |
choices=('last-linear-layer', 'full-model'), default="last-linear-layer")
|
352 |
parser.add_argument("--use_gpu", action='store_true')
|
353 |
|
354 |
-
parser.add_argument("--
|
|
|
355 |
parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
|
356 |
parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
|
357 |
default=1e-3)
|
@@ -360,17 +357,21 @@ def get_args():
|
|
360 |
return args
|
361 |
|
362 |
|
363 |
-
|
364 |
args = get_args()
|
365 |
seed_everything(args.seed)
|
|
|
|
|
|
|
366 |
|
367 |
print('Training Sentiment Classifier on SST...')
|
368 |
config = SimpleNamespace(
|
369 |
filepath='sst-classifier.pt',
|
370 |
lr=args.lr,
|
|
|
371 |
use_gpu=args.use_gpu,
|
372 |
epochs=args.epochs,
|
373 |
-
batch_size=args.
|
374 |
hidden_dropout_prob=args.hidden_dropout_prob,
|
375 |
train='data/ids-sst-train.csv',
|
376 |
dev='data/ids-sst-dev.csv',
|
@@ -389,9 +390,10 @@ if __name__ == "__main__":
|
|
389 |
config = SimpleNamespace(
|
390 |
filepath='cfimdb-classifier.pt',
|
391 |
lr=args.lr,
|
|
|
392 |
use_gpu=args.use_gpu,
|
393 |
epochs=args.epochs,
|
394 |
-
batch_size=
|
395 |
hidden_dropout_prob=args.hidden_dropout_prob,
|
396 |
train='data/ids-cfimdb-train.csv',
|
397 |
dev='data/ids-cfimdb-dev.csv',
|
@@ -405,3 +407,7 @@ if __name__ == "__main__":
|
|
405 |
|
406 |
print('Evaluating on cfimdb...')
|
407 |
test(config)
|
|
|
|
|
|
|
|
|
|
3 |
import csv
|
4 |
|
5 |
import torch
|
6 |
+
from tqdm import tqdm
|
7 |
import torch.nn.functional as F
|
8 |
from torch.utils.data import Dataset, DataLoader
|
9 |
from sklearn.metrics import f1_score, accuracy_score
|
|
|
11 |
from tokenizer import BertTokenizer
|
12 |
from bert import BertModel
|
13 |
from optimizer import AdamW
|
|
|
14 |
|
15 |
|
16 |
TQDM_DISABLE=False
|
|
|
34 |
In the SST dataset, there are 5 sentiment categories (from 0 - "negative" to 4 - "positive").
|
35 |
Thus, your forward() should return one logit for each of the 5 classes.
|
36 |
'''
|
37 |
+
def __init__(self, config, bert_model = None):
|
38 |
super(BertSentimentClassifier, self).__init__()
|
39 |
self.num_labels = config.num_labels
|
40 |
+
self.bert: BertModel = bert_model or BertModel.from_pretrained('bert-base-uncased')
|
41 |
|
42 |
# Pretrain mode does not require updating BERT paramters.
|
43 |
assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
|
|
|
59 |
# the training loop currently uses F.cross_entropy as the loss function.
|
60 |
|
61 |
# Get the embedding for each input token.
|
62 |
+
outputs = self.bert(input_ids, attention_mask)
|
63 |
+
pooler_output = outputs['pooler_output']
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
# Pass the [CLS] token representation through the classifier.
|
66 |
+
logits = self.classifier(self.dropout(pooler_output))
|
67 |
|
68 |
return logits
|
69 |
|
70 |
|
71 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
72 |
|
73 |
class SentimentDataset(Dataset):
|
74 |
def __init__(self, dataset, args):
|
75 |
self.dataset = dataset
|
76 |
self.p = args
|
|
|
77 |
|
78 |
def __len__(self):
|
79 |
return len(self.dataset)
|
|
|
86 |
labels = [x[1] for x in data]
|
87 |
sent_ids = [x[2] for x in data]
|
88 |
|
89 |
+
encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
|
90 |
token_ids = torch.LongTensor(encoding['input_ids'])
|
91 |
attention_mask = torch.LongTensor(encoding['attention_mask'])
|
92 |
labels = torch.LongTensor(labels)
|
|
|
94 |
return token_ids, attention_mask, labels, sents, sent_ids
|
95 |
|
96 |
def collate_fn(self, all_data):
|
97 |
+
token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)
|
98 |
|
99 |
batched_data = {
|
100 |
+
'token_ids': token_ids,
|
101 |
+
'attention_mask': attention_mask,
|
102 |
+
'labels': labels,
|
103 |
+
'sents': sents,
|
104 |
+
'sent_ids': sent_ids
|
105 |
+
}
|
106 |
|
107 |
return batched_data
|
108 |
|
|
|
111 |
def __init__(self, dataset, args):
|
112 |
self.dataset = dataset
|
113 |
self.p = args
|
|
|
114 |
|
115 |
def __len__(self):
|
116 |
return len(self.dataset)
|
|
|
122 |
sents = [x[0] for x in data]
|
123 |
sent_ids = [x[1] for x in data]
|
124 |
|
125 |
+
encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
|
126 |
token_ids = torch.LongTensor(encoding['input_ids'])
|
127 |
attention_mask = torch.LongTensor(encoding['attention_mask'])
|
128 |
|
|
|
132 |
token_ids, attention_mask, sents, sent_ids= self.pad_data(all_data)
|
133 |
|
134 |
batched_data = {
|
135 |
+
'token_ids': token_ids,
|
136 |
+
'attention_mask': attention_mask,
|
137 |
+
'sents': sents,
|
138 |
+
'sent_ids': sent_ids
|
139 |
+
}
|
140 |
|
141 |
return batched_data
|
142 |
|
143 |
|
144 |
# Load the data: a list of (sentence, label).
|
145 |
def load_data(filename, flag='train'):
|
146 |
+
num_labels = set()
|
147 |
data = []
|
148 |
+
with open(filename, 'r') as fp:
|
149 |
+
for record in csv.DictReader(fp, delimiter = '\t'):
|
150 |
+
if flag == 'test':
|
151 |
sent = record['sentence'].lower().strip()
|
152 |
sent_id = record['id'].lower().strip()
|
153 |
data.append((sent,sent_id))
|
154 |
+
else:
|
|
|
|
|
155 |
sent = record['sentence'].lower().strip()
|
156 |
sent_id = record['id'].lower().strip()
|
157 |
label = int(record['sentiment'].strip())
|
158 |
+
num_labels.add(label)
|
159 |
+
data.append((sent, label, sent_id))
|
|
|
160 |
print(f"load {len(data)} data from {filename}")
|
161 |
|
162 |
if flag == 'train':
|
|
|
244 |
dev_dataset = SentimentDataset(dev_data, args)
|
245 |
|
246 |
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
|
247 |
+
num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
|
248 |
dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
|
249 |
+
num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)
|
250 |
|
251 |
# Init model.
|
252 |
config = {'hidden_dropout_prob': args.hidden_dropout_prob,
|
|
|
302 |
def test(args):
|
303 |
with torch.no_grad():
|
304 |
device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
|
305 |
+
saved = torch.load(args.filepath, weights_only=False)
|
306 |
config = saved['model_config']
|
307 |
model = BertSentimentClassifier(config)
|
308 |
model.load_state_dict(saved['model'])
|
|
|
311 |
|
312 |
dev_data = load_data(args.dev, 'valid')
|
313 |
dev_dataset = SentimentDataset(dev_data, args)
|
314 |
+
dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
|
315 |
+
num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)
|
316 |
|
|
|
|
|
|
|
|
|
317 |
dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
|
318 |
print('DONE DEV')
|
319 |
+
print(f"dev acc :: {dev_acc :.3f}")
|
320 |
+
|
321 |
+
# ---- SKIP RUNNING ON TEST DATASET ---- #
|
322 |
+
# test_data = load_data(args.test, 'test')
|
323 |
+
# test_dataset = SentimentTestDataset(test_data, args)
|
324 |
+
# test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size,
|
325 |
+
# num_workers=args.num_cpu_cores, collate_fn=test_dataset.collate_fn)
|
326 |
+
# test_pred, test_sents, test_sent_ids = model_test_eval(test_dataloader, model, device)
|
327 |
+
# print('DONE TEST')
|
328 |
+
|
329 |
+
# ---- SKIP SAVING PREDICTIONS ----
|
330 |
+
# with open(args.dev_out, "w+") as f:
|
331 |
+
# f.write(f"id \t Predicted_Sentiment \n")
|
332 |
+
# for p, s in zip(dev_sent_ids,dev_pred):
|
333 |
+
# f.write(f"{p} , {s} \n")
|
334 |
+
# with open(args.test_out, "w+") as f:
|
335 |
+
# f.write(f"id \t Predicted_Sentiment \n")
|
336 |
+
# for p, s in zip(test_sent_ids,test_pred ):
|
337 |
+
# f.write(f"{p} , {s} \n")
|
338 |
|
339 |
|
340 |
def get_args():
|
341 |
parser = argparse.ArgumentParser()
|
342 |
parser.add_argument("--seed", type=int, default=11711)
|
343 |
+
parser.add_argument("--num-cpu-cores", type=int, default=4)
|
344 |
parser.add_argument("--epochs", type=int, default=10)
|
345 |
parser.add_argument("--fine-tune-mode", type=str,
|
346 |
help='last-linear-layer: the BERT parameters are frozen and the task specific head parameters are updated; full-model: BERT parameters are updated as well',
|
347 |
choices=('last-linear-layer', 'full-model'), default="last-linear-layer")
|
348 |
parser.add_argument("--use_gpu", action='store_true')
|
349 |
|
350 |
+
parser.add_argument("--batch_size_sst", help='64 can fit a 12GB GPU', type=int, default=8)
|
351 |
+
parser.add_argument("--batch_size_cfimdb", help='8 can fit a 12GB GPU', type=int, default=8)
|
352 |
parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
|
353 |
parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
|
354 |
default=1e-3)
|
|
|
357 |
return args
|
358 |
|
359 |
|
360 |
+
def main():
|
361 |
args = get_args()
|
362 |
seed_everything(args.seed)
|
363 |
+
torch.set_num_threads(args.num_cpu_cores)
|
364 |
+
|
365 |
+
print(torch.get_num_threads())
|
366 |
|
367 |
print('Training Sentiment Classifier on SST...')
|
368 |
config = SimpleNamespace(
|
369 |
filepath='sst-classifier.pt',
|
370 |
lr=args.lr,
|
371 |
+
num_cpu_cores=args.num_cpu_cores,
|
372 |
use_gpu=args.use_gpu,
|
373 |
epochs=args.epochs,
|
374 |
+
batch_size=args.batch_size_sst,
|
375 |
hidden_dropout_prob=args.hidden_dropout_prob,
|
376 |
train='data/ids-sst-train.csv',
|
377 |
dev='data/ids-sst-dev.csv',
|
|
|
390 |
config = SimpleNamespace(
|
391 |
filepath='cfimdb-classifier.pt',
|
392 |
lr=args.lr,
|
393 |
+
num_cpu_cores=args.num_cpu_cores,
|
394 |
use_gpu=args.use_gpu,
|
395 |
epochs=args.epochs,
|
396 |
+
batch_size=args.batch_size_cfimdb,
|
397 |
hidden_dropout_prob=args.hidden_dropout_prob,
|
398 |
train='data/ids-cfimdb-train.csv',
|
399 |
dev='data/ids-cfimdb-dev.csv',
|
|
|
407 |
|
408 |
print('Evaluating on cfimdb...')
|
409 |
test(config)
|
410 |
+
|
411 |
+
|
412 |
+
if __name__ == "__main__":
|
413 |
+
main()
|
data/{sts-test-student.csv → nli-dev.parquet}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c267496435885e724abc71e53669fae59db875bfa13389eab8f9b0b2dfb2b32e
|
3 |
+
size 782233
|
data/{sts-train.csv → nli-test.parquet}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:01688df43ae4c019a86144a0d2351146b124688a55f285071cccd156225a5fdf
|
3 |
+
size 810423
|
data/{quora-test-student.csv → nli-train.parquet}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f9aeca80b1bda983ee316f854ebc37af8341877fb932dd6a2c6aba978ad112a5
|
3 |
+
size 38396324
|
data/{sts-dev.csv → stsb-dev.parquet}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9c6e0e9881f1b398abe3e439a482f4686305c3784568c462f6bba58bdff03b0a
|
3 |
+
size 142187
|
data/{quora-dev.csv → stsb-test.parquet}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8acbc291c50977d8655934952956016c3e049c2fe04f8a6c454c1bf6acc42ca1
|
3 |
+
size 108100
|
data/stsb-train.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eae324ff1eac2d0ba769851736eb7232eda64f370a16eb20e74a2c5f8f5fafe0
|
3 |
+
size 470612
|
data/{quora-train.csv → twitter-unsup.csv}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5a7af1ec5fc749ec8e5ea13c574aeb5c06254aa1c081e3421868079d5356b3f4
|
3 |
+
size 20895533
|
datasets.py
DELETED
@@ -1,272 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
|
3 |
-
'''
|
4 |
-
This module contains our Dataset classes and functions that load the three datasets
|
5 |
-
for training and evaluating multitask BERT.
|
6 |
-
|
7 |
-
Feel free to edit code in this file if you wish to modify the way in which the data
|
8 |
-
examples are preprocessed.
|
9 |
-
'''
|
10 |
-
|
11 |
-
import csv
|
12 |
-
|
13 |
-
import torch
|
14 |
-
from torch.utils.data import Dataset
|
15 |
-
from tokenizer import BertTokenizer
|
16 |
-
|
17 |
-
|
18 |
-
def preprocess_string(s):
|
19 |
-
return ' '.join(s.lower()
|
20 |
-
.replace('.', ' .')
|
21 |
-
.replace('?', ' ?')
|
22 |
-
.replace(',', ' ,')
|
23 |
-
.replace('\'', ' \'')
|
24 |
-
.split())
|
25 |
-
|
26 |
-
|
27 |
-
class SentenceClassificationDataset(Dataset):
|
28 |
-
def __init__(self, dataset, args):
|
29 |
-
self.dataset = dataset
|
30 |
-
self.p = args
|
31 |
-
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
32 |
-
|
33 |
-
def __len__(self):
|
34 |
-
return len(self.dataset)
|
35 |
-
|
36 |
-
def __getitem__(self, idx):
|
37 |
-
return self.dataset[idx]
|
38 |
-
|
39 |
-
def pad_data(self, data):
|
40 |
-
|
41 |
-
sents = [x[0] for x in data]
|
42 |
-
labels = [x[1] for x in data]
|
43 |
-
sent_ids = [x[2] for x in data]
|
44 |
-
|
45 |
-
encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
|
46 |
-
token_ids = torch.LongTensor(encoding['input_ids'])
|
47 |
-
attention_mask = torch.LongTensor(encoding['attention_mask'])
|
48 |
-
labels = torch.LongTensor(labels)
|
49 |
-
|
50 |
-
return token_ids, attention_mask, labels, sents, sent_ids
|
51 |
-
|
52 |
-
def collate_fn(self, all_data):
|
53 |
-
token_ids, attention_mask, labels, sents, sent_ids= self.pad_data(all_data)
|
54 |
-
|
55 |
-
batched_data = {
|
56 |
-
'token_ids': token_ids,
|
57 |
-
'attention_mask': attention_mask,
|
58 |
-
'labels': labels,
|
59 |
-
'sents': sents,
|
60 |
-
'sent_ids': sent_ids
|
61 |
-
}
|
62 |
-
|
63 |
-
return batched_data
|
64 |
-
|
65 |
-
|
66 |
-
# Unlike SentenceClassificationDataset, we do not load labels in SentenceClassificationTestDataset.
|
67 |
-
class SentenceClassificationTestDataset(Dataset):
|
68 |
-
def __init__(self, dataset, args):
|
69 |
-
self.dataset = dataset
|
70 |
-
self.p = args
|
71 |
-
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
72 |
-
|
73 |
-
def __len__(self):
|
74 |
-
return len(self.dataset)
|
75 |
-
|
76 |
-
def __getitem__(self, idx):
|
77 |
-
return self.dataset[idx]
|
78 |
-
|
79 |
-
def pad_data(self, data):
|
80 |
-
sents = [x[0] for x in data]
|
81 |
-
sent_ids = [x[1] for x in data]
|
82 |
-
|
83 |
-
encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
|
84 |
-
token_ids = torch.LongTensor(encoding['input_ids'])
|
85 |
-
attention_mask = torch.LongTensor(encoding['attention_mask'])
|
86 |
-
|
87 |
-
return token_ids, attention_mask, sents, sent_ids
|
88 |
-
|
89 |
-
def collate_fn(self, all_data):
|
90 |
-
token_ids, attention_mask, sents, sent_ids= self.pad_data(all_data)
|
91 |
-
|
92 |
-
batched_data = {
|
93 |
-
'token_ids': token_ids,
|
94 |
-
'attention_mask': attention_mask,
|
95 |
-
'sents': sents,
|
96 |
-
'sent_ids': sent_ids
|
97 |
-
}
|
98 |
-
|
99 |
-
return batched_data
|
100 |
-
|
101 |
-
|
102 |
-
class SentencePairDataset(Dataset):
|
103 |
-
def __init__(self, dataset, args, isRegression=False):
|
104 |
-
self.dataset = dataset
|
105 |
-
self.p = args
|
106 |
-
self.isRegression = isRegression
|
107 |
-
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
108 |
-
|
109 |
-
def __len__(self):
|
110 |
-
return len(self.dataset)
|
111 |
-
|
112 |
-
def __getitem__(self, idx):
|
113 |
-
return self.dataset[idx]
|
114 |
-
|
115 |
-
def pad_data(self, data):
|
116 |
-
sent1 = [x[0] for x in data]
|
117 |
-
sent2 = [x[1] for x in data]
|
118 |
-
labels = [x[2] for x in data]
|
119 |
-
sent_ids = [x[3] for x in data]
|
120 |
-
|
121 |
-
encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
|
122 |
-
encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)
|
123 |
-
|
124 |
-
token_ids = torch.LongTensor(encoding1['input_ids'])
|
125 |
-
attention_mask = torch.LongTensor(encoding1['attention_mask'])
|
126 |
-
token_type_ids = torch.LongTensor(encoding1['token_type_ids'])
|
127 |
-
|
128 |
-
token_ids2 = torch.LongTensor(encoding2['input_ids'])
|
129 |
-
attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
|
130 |
-
token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])
|
131 |
-
if self.isRegression:
|
132 |
-
labels = torch.DoubleTensor(labels)
|
133 |
-
else:
|
134 |
-
labels = torch.LongTensor(labels)
|
135 |
-
|
136 |
-
return (token_ids, token_type_ids, attention_mask,
|
137 |
-
token_ids2, token_type_ids2, attention_mask2,
|
138 |
-
labels,sent_ids)
|
139 |
-
|
140 |
-
def collate_fn(self, all_data):
|
141 |
-
(token_ids, token_type_ids, attention_mask,
|
142 |
-
token_ids2, token_type_ids2, attention_mask2,
|
143 |
-
labels, sent_ids) = self.pad_data(all_data)
|
144 |
-
|
145 |
-
batched_data = {
|
146 |
-
'token_ids_1': token_ids,
|
147 |
-
'token_type_ids_1': token_type_ids,
|
148 |
-
'attention_mask_1': attention_mask,
|
149 |
-
'token_ids_2': token_ids2,
|
150 |
-
'token_type_ids_2': token_type_ids2,
|
151 |
-
'attention_mask_2': attention_mask2,
|
152 |
-
'labels': labels,
|
153 |
-
'sent_ids': sent_ids
|
154 |
-
}
|
155 |
-
|
156 |
-
return batched_data
|
157 |
-
|
158 |
-
|
159 |
-
# Unlike SentencePairDataset, we do not load labels in SentencePairTestDataset.
|
160 |
-
class SentencePairTestDataset(Dataset):
|
161 |
-
def __init__(self, dataset, args):
|
162 |
-
self.dataset = dataset
|
163 |
-
self.p = args
|
164 |
-
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
165 |
-
|
166 |
-
def __len__(self):
|
167 |
-
return len(self.dataset)
|
168 |
-
|
169 |
-
def __getitem__(self, idx):
|
170 |
-
return self.dataset[idx]
|
171 |
-
|
172 |
-
def pad_data(self, data):
|
173 |
-
sent1 = [x[0] for x in data]
|
174 |
-
sent2 = [x[1] for x in data]
|
175 |
-
sent_ids = [x[2] for x in data]
|
176 |
-
|
177 |
-
encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
|
178 |
-
encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)
|
179 |
-
|
180 |
-
token_ids = torch.LongTensor(encoding1['input_ids'])
|
181 |
-
attention_mask = torch.LongTensor(encoding1['attention_mask'])
|
182 |
-
token_type_ids = torch.LongTensor(encoding1['token_type_ids'])
|
183 |
-
|
184 |
-
token_ids2 = torch.LongTensor(encoding2['input_ids'])
|
185 |
-
attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
|
186 |
-
token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])
|
187 |
-
|
188 |
-
|
189 |
-
return (token_ids, token_type_ids, attention_mask,
|
190 |
-
token_ids2, token_type_ids2, attention_mask2,
|
191 |
-
sent_ids)
|
192 |
-
|
193 |
-
def collate_fn(self, all_data):
|
194 |
-
(token_ids, token_type_ids, attention_mask,
|
195 |
-
token_ids2, token_type_ids2, attention_mask2,
|
196 |
-
sent_ids) = self.pad_data(all_data)
|
197 |
-
|
198 |
-
batched_data = {
|
199 |
-
'token_ids_1': token_ids,
|
200 |
-
'token_type_ids_1': token_type_ids,
|
201 |
-
'attention_mask_1': attention_mask,
|
202 |
-
'token_ids_2': token_ids2,
|
203 |
-
'token_type_ids_2': token_type_ids2,
|
204 |
-
'attention_mask_2': attention_mask2,
|
205 |
-
'sent_ids': sent_ids
|
206 |
-
}
|
207 |
-
|
208 |
-
return batched_data
|
209 |
-
|
210 |
-
|
211 |
-
def load_multitask_data(sentiment_filename,paraphrase_filename,similarity_filename,split='train'):
|
212 |
-
sentiment_data = []
|
213 |
-
num_labels = {}
|
214 |
-
if split == 'test':
|
215 |
-
with open(sentiment_filename, 'r') as fp:
|
216 |
-
for record in csv.DictReader(fp,delimiter = '\t'):
|
217 |
-
sent = record['sentence'].lower().strip()
|
218 |
-
sent_id = record['id'].lower().strip()
|
219 |
-
sentiment_data.append((sent,sent_id))
|
220 |
-
else:
|
221 |
-
with open(sentiment_filename, 'r') as fp:
|
222 |
-
for record in csv.DictReader(fp,delimiter = '\t'):
|
223 |
-
sent = record['sentence'].lower().strip()
|
224 |
-
sent_id = record['id'].lower().strip()
|
225 |
-
label = int(record['sentiment'].strip())
|
226 |
-
if label not in num_labels:
|
227 |
-
num_labels[label] = len(num_labels)
|
228 |
-
sentiment_data.append((sent, label,sent_id))
|
229 |
-
|
230 |
-
print(f"Loaded {len(sentiment_data)} {split} examples from {sentiment_filename}")
|
231 |
-
|
232 |
-
paraphrase_data = []
|
233 |
-
if split == 'test':
|
234 |
-
with open(paraphrase_filename, 'r') as fp:
|
235 |
-
for record in csv.DictReader(fp,delimiter = '\t'):
|
236 |
-
sent_id = record['id'].lower().strip()
|
237 |
-
paraphrase_data.append((preprocess_string(record['sentence1']),
|
238 |
-
preprocess_string(record['sentence2']),
|
239 |
-
sent_id))
|
240 |
-
|
241 |
-
else:
|
242 |
-
with open(paraphrase_filename, 'r') as fp:
|
243 |
-
for record in csv.DictReader(fp,delimiter = '\t'):
|
244 |
-
try:
|
245 |
-
sent_id = record['id'].lower().strip()
|
246 |
-
paraphrase_data.append((preprocess_string(record['sentence1']),
|
247 |
-
preprocess_string(record['sentence2']),
|
248 |
-
int(float(record['is_duplicate'])),sent_id))
|
249 |
-
except:
|
250 |
-
pass
|
251 |
-
|
252 |
-
print(f"Loaded {len(paraphrase_data)} {split} examples from {paraphrase_filename}")
|
253 |
-
|
254 |
-
similarity_data = []
|
255 |
-
if split == 'test':
|
256 |
-
with open(similarity_filename, 'r') as fp:
|
257 |
-
for record in csv.DictReader(fp,delimiter = '\t'):
|
258 |
-
sent_id = record['id'].lower().strip()
|
259 |
-
similarity_data.append((preprocess_string(record['sentence1']),
|
260 |
-
preprocess_string(record['sentence2'])
|
261 |
-
,sent_id))
|
262 |
-
else:
|
263 |
-
with open(similarity_filename, 'r') as fp:
|
264 |
-
for record in csv.DictReader(fp,delimiter = '\t'):
|
265 |
-
sent_id = record['id'].lower().strip()
|
266 |
-
similarity_data.append((preprocess_string(record['sentence1']),
|
267 |
-
preprocess_string(record['sentence2']),
|
268 |
-
float(record['similarity']),sent_id))
|
269 |
-
|
270 |
-
print(f"Loaded {len(similarity_data)} {split} examples from {similarity_filename}")
|
271 |
-
|
272 |
-
return sentiment_data, num_labels, paraphrase_data, similarity_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
justfile
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python classifier.py --num-cpu-cores 8 --batch_size_sst 64 --batch_size_cfimdb 8
|
multitask_classifier.py
DELETED
@@ -1,340 +0,0 @@
|
|
1 |
-
'''
|
2 |
-
Multitask BERT class, starter training code, evaluation, and test code.
|
3 |
-
|
4 |
-
Of note are:
|
5 |
-
* class MultitaskBERT: Your implementation of multitask BERT.
|
6 |
-
* function train_multitask: Training procedure for MultitaskBERT. Starter code
|
7 |
-
copies training procedure from `classifier.py` (single-task SST).
|
8 |
-
* function test_multitask: Test procedure for MultitaskBERT. This function generates
|
9 |
-
the required files for submission.
|
10 |
-
|
11 |
-
Running `python multitask_classifier.py` trains and tests your MultitaskBERT and
|
12 |
-
writes all required submission files.
|
13 |
-
'''
|
14 |
-
|
15 |
-
import random, numpy as np, argparse
|
16 |
-
from types import SimpleNamespace
|
17 |
-
|
18 |
-
import torch
|
19 |
-
from torch import nn
|
20 |
-
import torch.nn.functional as F
|
21 |
-
from torch.utils.data import DataLoader
|
22 |
-
|
23 |
-
from bert import BertModel
|
24 |
-
from optimizer import AdamW
|
25 |
-
from tqdm import tqdm
|
26 |
-
|
27 |
-
from datasets import (
|
28 |
-
SentenceClassificationDataset,
|
29 |
-
SentenceClassificationTestDataset,
|
30 |
-
SentencePairDataset,
|
31 |
-
SentencePairTestDataset,
|
32 |
-
load_multitask_data
|
33 |
-
)
|
34 |
-
|
35 |
-
from evaluation import model_eval_sst, model_eval_multitask, model_eval_test_multitask
|
36 |
-
|
37 |
-
|
38 |
-
TQDM_DISABLE=False
|
39 |
-
|
40 |
-
|
41 |
-
# Fix the random seed.
|
42 |
-
def seed_everything(seed=11711):
|
43 |
-
random.seed(seed)
|
44 |
-
np.random.seed(seed)
|
45 |
-
torch.manual_seed(seed)
|
46 |
-
torch.cuda.manual_seed(seed)
|
47 |
-
torch.cuda.manual_seed_all(seed)
|
48 |
-
torch.backends.cudnn.benchmark = False
|
49 |
-
torch.backends.cudnn.deterministic = True
|
50 |
-
|
51 |
-
|
52 |
-
BERT_HIDDEN_SIZE = 768
|
53 |
-
N_SENTIMENT_CLASSES = 5
|
54 |
-
|
55 |
-
|
56 |
-
class MultitaskBERT(nn.Module):
|
57 |
-
'''
|
58 |
-
This module should use BERT for 3 tasks:
|
59 |
-
|
60 |
-
- Sentiment classification (predict_sentiment)
|
61 |
-
- Paraphrase detection (predict_paraphrase)
|
62 |
-
- Semantic Textual Similarity (predict_similarity)
|
63 |
-
'''
|
64 |
-
def __init__(self, config):
|
65 |
-
super(MultitaskBERT, self).__init__()
|
66 |
-
self.bert = BertModel.from_pretrained('bert-base-uncased')
|
67 |
-
# last-linear-layer mode does not require updating BERT paramters.
|
68 |
-
assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
|
69 |
-
for param in self.bert.parameters():
|
70 |
-
if config.fine_tune_mode == 'last-linear-layer':
|
71 |
-
param.requires_grad = False
|
72 |
-
elif config.fine_tune_mode == 'full-model':
|
73 |
-
param.requires_grad = True
|
74 |
-
# You will want to add layers here to perform the downstream tasks.
|
75 |
-
### TODO
|
76 |
-
raise NotImplementedError
|
77 |
-
|
78 |
-
|
79 |
-
def forward(self, input_ids, attention_mask):
|
80 |
-
'Takes a batch of sentences and produces embeddings for them.'
|
81 |
-
# The final BERT embedding is the hidden state of [CLS] token (the first token)
|
82 |
-
# Here, you can start by just returning the embeddings straight from BERT.
|
83 |
-
# When thinking of improvements, you can later try modifying this
|
84 |
-
# (e.g., by adding other layers).
|
85 |
-
### TODO
|
86 |
-
raise NotImplementedError
|
87 |
-
|
88 |
-
|
89 |
-
def predict_sentiment(self, input_ids, attention_mask):
|
90 |
-
'''Given a batch of sentences, outputs logits for classifying sentiment.
|
91 |
-
There are 5 sentiment classes:
|
92 |
-
(0 - negative, 1- somewhat negative, 2- neutral, 3- somewhat positive, 4- positive)
|
93 |
-
Thus, your output should contain 5 logits for each sentence.
|
94 |
-
'''
|
95 |
-
### TODO
|
96 |
-
raise NotImplementedError
|
97 |
-
|
98 |
-
|
99 |
-
def predict_paraphrase(self,
|
100 |
-
input_ids_1, attention_mask_1,
|
101 |
-
input_ids_2, attention_mask_2):
|
102 |
-
'''Given a batch of pairs of sentences, outputs a single logit for predicting whether they are paraphrases.
|
103 |
-
Note that your output should be unnormalized (a logit); it will be passed to the sigmoid function
|
104 |
-
during evaluation.
|
105 |
-
'''
|
106 |
-
### TODO
|
107 |
-
raise NotImplementedError
|
108 |
-
|
109 |
-
|
110 |
-
def predict_similarity(self,
|
111 |
-
input_ids_1, attention_mask_1,
|
112 |
-
input_ids_2, attention_mask_2):
|
113 |
-
'''Given a batch of pairs of sentences, outputs a single logit corresponding to how similar they are.
|
114 |
-
Note that your output should be unnormalized (a logit).
|
115 |
-
'''
|
116 |
-
### TODO
|
117 |
-
raise NotImplementedError
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
def save_model(model, optimizer, args, config, filepath):
|
123 |
-
save_info = {
|
124 |
-
'model': model.state_dict(),
|
125 |
-
'optim': optimizer.state_dict(),
|
126 |
-
'args': args,
|
127 |
-
'model_config': config,
|
128 |
-
'system_rng': random.getstate(),
|
129 |
-
'numpy_rng': np.random.get_state(),
|
130 |
-
'torch_rng': torch.random.get_rng_state(),
|
131 |
-
}
|
132 |
-
|
133 |
-
torch.save(save_info, filepath)
|
134 |
-
print(f"save the model to {filepath}")
|
135 |
-
|
136 |
-
|
137 |
-
def train_multitask(args):
|
138 |
-
'''Train MultitaskBERT.
|
139 |
-
|
140 |
-
Currently only trains on SST dataset. The way you incorporate training examples
|
141 |
-
from other datasets into the training procedure is up to you. To begin, take a
|
142 |
-
look at test_multitask below to see how you can use the custom torch `Dataset`s
|
143 |
-
in datasets.py to load in examples from the Quora and SemEval datasets.
|
144 |
-
'''
|
145 |
-
device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
|
146 |
-
# Create the data and its corresponding datasets and dataloader.
|
147 |
-
sst_train_data, num_labels,para_train_data, sts_train_data = load_multitask_data(args.sst_train,args.para_train,args.sts_train, split ='train')
|
148 |
-
sst_dev_data, num_labels,para_dev_data, sts_dev_data = load_multitask_data(args.sst_dev,args.para_dev,args.sts_dev, split ='train')
|
149 |
-
|
150 |
-
sst_train_data = SentenceClassificationDataset(sst_train_data, args)
|
151 |
-
sst_dev_data = SentenceClassificationDataset(sst_dev_data, args)
|
152 |
-
|
153 |
-
sst_train_dataloader = DataLoader(sst_train_data, shuffle=True, batch_size=args.batch_size,
|
154 |
-
collate_fn=sst_train_data.collate_fn)
|
155 |
-
sst_dev_dataloader = DataLoader(sst_dev_data, shuffle=False, batch_size=args.batch_size,
|
156 |
-
collate_fn=sst_dev_data.collate_fn)
|
157 |
-
|
158 |
-
# Init model.
|
159 |
-
config = {'hidden_dropout_prob': args.hidden_dropout_prob,
|
160 |
-
'num_labels': num_labels,
|
161 |
-
'hidden_size': 768,
|
162 |
-
'data_dir': '.',
|
163 |
-
'fine_tune_mode': args.fine_tune_mode}
|
164 |
-
|
165 |
-
config = SimpleNamespace(**config)
|
166 |
-
|
167 |
-
model = MultitaskBERT(config)
|
168 |
-
model = model.to(device)
|
169 |
-
|
170 |
-
lr = args.lr
|
171 |
-
optimizer = AdamW(model.parameters(), lr=lr)
|
172 |
-
best_dev_acc = 0
|
173 |
-
|
174 |
-
# Run for the specified number of epochs.
|
175 |
-
for epoch in range(args.epochs):
|
176 |
-
model.train()
|
177 |
-
train_loss = 0
|
178 |
-
num_batches = 0
|
179 |
-
for batch in tqdm(sst_train_dataloader, desc=f'train-{epoch}', disable=TQDM_DISABLE):
|
180 |
-
b_ids, b_mask, b_labels = (batch['token_ids'],
|
181 |
-
batch['attention_mask'], batch['labels'])
|
182 |
-
|
183 |
-
b_ids = b_ids.to(device)
|
184 |
-
b_mask = b_mask.to(device)
|
185 |
-
b_labels = b_labels.to(device)
|
186 |
-
|
187 |
-
optimizer.zero_grad()
|
188 |
-
logits = model.predict_sentiment(b_ids, b_mask)
|
189 |
-
loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size
|
190 |
-
|
191 |
-
loss.backward()
|
192 |
-
optimizer.step()
|
193 |
-
|
194 |
-
train_loss += loss.item()
|
195 |
-
num_batches += 1
|
196 |
-
|
197 |
-
train_loss = train_loss / (num_batches)
|
198 |
-
|
199 |
-
train_acc, train_f1, *_ = model_eval_sst(sst_train_dataloader, model, device)
|
200 |
-
dev_acc, dev_f1, *_ = model_eval_sst(sst_dev_dataloader, model, device)
|
201 |
-
|
202 |
-
if dev_acc > best_dev_acc:
|
203 |
-
best_dev_acc = dev_acc
|
204 |
-
save_model(model, optimizer, args, config, args.filepath)
|
205 |
-
|
206 |
-
print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")
|
207 |
-
|
208 |
-
|
209 |
-
def test_multitask(args):
|
210 |
-
'''Test and save predictions on the dev and test sets of all three tasks.'''
|
211 |
-
with torch.no_grad():
|
212 |
-
device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
|
213 |
-
saved = torch.load(args.filepath)
|
214 |
-
config = saved['model_config']
|
215 |
-
|
216 |
-
model = MultitaskBERT(config)
|
217 |
-
model.load_state_dict(saved['model'])
|
218 |
-
model = model.to(device)
|
219 |
-
print(f"Loaded model to test from {args.filepath}")
|
220 |
-
|
221 |
-
sst_test_data, num_labels,para_test_data, sts_test_data = \
|
222 |
-
load_multitask_data(args.sst_test,args.para_test, args.sts_test, split='test')
|
223 |
-
|
224 |
-
sst_dev_data, num_labels,para_dev_data, sts_dev_data = \
|
225 |
-
load_multitask_data(args.sst_dev,args.para_dev,args.sts_dev,split='dev')
|
226 |
-
|
227 |
-
sst_test_data = SentenceClassificationTestDataset(sst_test_data, args)
|
228 |
-
sst_dev_data = SentenceClassificationDataset(sst_dev_data, args)
|
229 |
-
|
230 |
-
sst_test_dataloader = DataLoader(sst_test_data, shuffle=True, batch_size=args.batch_size,
|
231 |
-
collate_fn=sst_test_data.collate_fn)
|
232 |
-
sst_dev_dataloader = DataLoader(sst_dev_data, shuffle=False, batch_size=args.batch_size,
|
233 |
-
collate_fn=sst_dev_data.collate_fn)
|
234 |
-
|
235 |
-
para_test_data = SentencePairTestDataset(para_test_data, args)
|
236 |
-
para_dev_data = SentencePairDataset(para_dev_data, args)
|
237 |
-
|
238 |
-
para_test_dataloader = DataLoader(para_test_data, shuffle=True, batch_size=args.batch_size,
|
239 |
-
collate_fn=para_test_data.collate_fn)
|
240 |
-
para_dev_dataloader = DataLoader(para_dev_data, shuffle=False, batch_size=args.batch_size,
|
241 |
-
collate_fn=para_dev_data.collate_fn)
|
242 |
-
|
243 |
-
sts_test_data = SentencePairTestDataset(sts_test_data, args)
|
244 |
-
sts_dev_data = SentencePairDataset(sts_dev_data, args, isRegression=True)
|
245 |
-
|
246 |
-
sts_test_dataloader = DataLoader(sts_test_data, shuffle=True, batch_size=args.batch_size,
|
247 |
-
collate_fn=sts_test_data.collate_fn)
|
248 |
-
sts_dev_dataloader = DataLoader(sts_dev_data, shuffle=False, batch_size=args.batch_size,
|
249 |
-
collate_fn=sts_dev_data.collate_fn)
|
250 |
-
|
251 |
-
dev_sentiment_accuracy,dev_sst_y_pred, dev_sst_sent_ids, \
|
252 |
-
dev_paraphrase_accuracy, dev_para_y_pred, dev_para_sent_ids, \
|
253 |
-
dev_sts_corr, dev_sts_y_pred, dev_sts_sent_ids = model_eval_multitask(sst_dev_dataloader,
|
254 |
-
para_dev_dataloader,
|
255 |
-
sts_dev_dataloader, model, device)
|
256 |
-
|
257 |
-
test_sst_y_pred, \
|
258 |
-
test_sst_sent_ids, test_para_y_pred, test_para_sent_ids, test_sts_y_pred, test_sts_sent_ids = \
|
259 |
-
model_eval_test_multitask(sst_test_dataloader,
|
260 |
-
para_test_dataloader,
|
261 |
-
sts_test_dataloader, model, device)
|
262 |
-
|
263 |
-
with open(args.sst_dev_out, "w+") as f:
|
264 |
-
print(f"dev sentiment acc :: {dev_sentiment_accuracy :.3f}")
|
265 |
-
f.write(f"id \t Predicted_Sentiment \n")
|
266 |
-
for p, s in zip(dev_sst_sent_ids, dev_sst_y_pred):
|
267 |
-
f.write(f"{p} , {s} \n")
|
268 |
-
|
269 |
-
with open(args.sst_test_out, "w+") as f:
|
270 |
-
f.write(f"id \t Predicted_Sentiment \n")
|
271 |
-
for p, s in zip(test_sst_sent_ids, test_sst_y_pred):
|
272 |
-
f.write(f"{p} , {s} \n")
|
273 |
-
|
274 |
-
with open(args.para_dev_out, "w+") as f:
|
275 |
-
print(f"dev paraphrase acc :: {dev_paraphrase_accuracy :.3f}")
|
276 |
-
f.write(f"id \t Predicted_Is_Paraphrase \n")
|
277 |
-
for p, s in zip(dev_para_sent_ids, dev_para_y_pred):
|
278 |
-
f.write(f"{p} , {s} \n")
|
279 |
-
|
280 |
-
with open(args.para_test_out, "w+") as f:
|
281 |
-
f.write(f"id \t Predicted_Is_Paraphrase \n")
|
282 |
-
for p, s in zip(test_para_sent_ids, test_para_y_pred):
|
283 |
-
f.write(f"{p} , {s} \n")
|
284 |
-
|
285 |
-
with open(args.sts_dev_out, "w+") as f:
|
286 |
-
print(f"dev sts corr :: {dev_sts_corr :.3f}")
|
287 |
-
f.write(f"id \t Predicted_Similiary \n")
|
288 |
-
for p, s in zip(dev_sts_sent_ids, dev_sts_y_pred):
|
289 |
-
f.write(f"{p} , {s} \n")
|
290 |
-
|
291 |
-
with open(args.sts_test_out, "w+") as f:
|
292 |
-
f.write(f"id \t Predicted_Similiary \n")
|
293 |
-
for p, s in zip(test_sts_sent_ids, test_sts_y_pred):
|
294 |
-
f.write(f"{p} , {s} \n")
|
295 |
-
|
296 |
-
|
297 |
-
def get_args():
|
298 |
-
parser = argparse.ArgumentParser()
|
299 |
-
parser.add_argument("--sst_train", type=str, default="data/ids-sst-train.csv")
|
300 |
-
parser.add_argument("--sst_dev", type=str, default="data/ids-sst-dev.csv")
|
301 |
-
parser.add_argument("--sst_test", type=str, default="data/ids-sst-test-student.csv")
|
302 |
-
|
303 |
-
parser.add_argument("--para_train", type=str, default="data/quora-train.csv")
|
304 |
-
parser.add_argument("--para_dev", type=str, default="data/quora-dev.csv")
|
305 |
-
parser.add_argument("--para_test", type=str, default="data/quora-test-student.csv")
|
306 |
-
|
307 |
-
parser.add_argument("--sts_train", type=str, default="data/sts-train.csv")
|
308 |
-
parser.add_argument("--sts_dev", type=str, default="data/sts-dev.csv")
|
309 |
-
parser.add_argument("--sts_test", type=str, default="data/sts-test-student.csv")
|
310 |
-
|
311 |
-
parser.add_argument("--seed", type=int, default=11711)
|
312 |
-
parser.add_argument("--epochs", type=int, default=10)
|
313 |
-
parser.add_argument("--fine-tune-mode", type=str,
|
314 |
-
help='last-linear-layer: the BERT parameters are frozen and the task specific head parameters are updated; full-model: BERT parameters are updated as well',
|
315 |
-
choices=('last-linear-layer', 'full-model'), default="last-linear-layer")
|
316 |
-
parser.add_argument("--use_gpu", action='store_true')
|
317 |
-
|
318 |
-
parser.add_argument("--sst_dev_out", type=str, default="predictions/sst-dev-output.csv")
|
319 |
-
parser.add_argument("--sst_test_out", type=str, default="predictions/sst-test-output.csv")
|
320 |
-
|
321 |
-
parser.add_argument("--para_dev_out", type=str, default="predictions/para-dev-output.csv")
|
322 |
-
parser.add_argument("--para_test_out", type=str, default="predictions/para-test-output.csv")
|
323 |
-
|
324 |
-
parser.add_argument("--sts_dev_out", type=str, default="predictions/sts-dev-output.csv")
|
325 |
-
parser.add_argument("--sts_test_out", type=str, default="predictions/sts-test-output.csv")
|
326 |
-
|
327 |
-
parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
|
328 |
-
parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
|
329 |
-
parser.add_argument("--lr", type=float, help="learning rate", default=1e-5)
|
330 |
-
|
331 |
-
args = parser.parse_args()
|
332 |
-
return args
|
333 |
-
|
334 |
-
|
335 |
-
if __name__ == "__main__":
|
336 |
-
args = get_args()
|
337 |
-
args.filepath = f'{args.fine_tune_mode}-{args.epochs}-{args.lr}-multitask.pt' # Save path.
|
338 |
-
seed_everything(args.seed) # Fix the seed for reproducibility.
|
339 |
-
train_multitask(args)
|
340 |
-
test_multitask(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prompt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
Tôi muốn finetune minBERT bằng phương pháp Unsupervised SimCSE để thực hiện sentiment analysis nhưng chưa biết phải làm như thế nào. theo như tôi hiểu thì tôi sẽ finetune mô hình minBERT bằng SimCSE để có được embeddings tốt hơn, sau đó sẽ dùng embeddings này để truyền qua SentimentClassifier để phân loại. Tuy nhiên, hướng tiếp cận đúng đắn là gì?
|
2 |
+
Tôi đã nghĩ đến hai cách sau đây (hoặc có thể cách khác nhưng chưa nghĩ ra). Bạn xem xét thử nhé!
|
3 |
+
1. Finetune minBERT bằng SimCSE trước rồi mới finetune SentimentClassifier: sử dụng dataset STS-B hoặc Twitter Sentiment Dataset để finetune minBERT, rồi đánh giá độ
|
trainings/last-layer-w-dropout.txt
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
Training Sentiment Classifier on SST...
|
2 |
load 8544 data from data/ids-sst-train.csv
|
3 |
load 1101 data from data/ids-sst-dev.csv
|
4 |
Epoch 0: train loss :: 1.458, train acc :: 0.460, dev acc :: 0.442
|
@@ -14,7 +14,7 @@ Epoch 9: train loss :: 1.227, train acc :: 0.509, dev acc :: 0.475
|
|
14 |
Evaluating on SST...
|
15 |
load model from sst-classifier.pt
|
16 |
load 1101 data from data/ids-sst-dev.csv
|
17 |
-
DONE DEV
|
18 |
DONE Test
|
19 |
dev acc :: 0.475
|
20 |
Training Sentiment Classifier on cfimdb...
|
@@ -33,6 +33,6 @@ Epoch 9: train loss :: 0.407, train acc :: 0.895, dev acc :: 0.873
|
|
33 |
Evaluating on cfimdb...
|
34 |
load model from cfimdb-classifier.pt
|
35 |
load 245 data from data/ids-cfimdb-dev.csv
|
36 |
-
DONE DEV
|
37 |
-
DONE Test
|
38 |
dev acc :: 0.873
|
|
|
1 |
+
Training Sentiment Classifier on SST...
|
2 |
load 8544 data from data/ids-sst-train.csv
|
3 |
load 1101 data from data/ids-sst-dev.csv
|
4 |
Epoch 0: train loss :: 1.458, train acc :: 0.460, dev acc :: 0.442
|
|
|
14 |
Evaluating on SST...
|
15 |
load model from sst-classifier.pt
|
16 |
load 1101 data from data/ids-sst-dev.csv
|
17 |
+
DONE DEV
|
18 |
DONE Test
|
19 |
dev acc :: 0.475
|
20 |
Training Sentiment Classifier on cfimdb...
|
|
|
33 |
Evaluating on cfimdb...
|
34 |
load model from cfimdb-classifier.pt
|
35 |
load 245 data from data/ids-cfimdb-dev.csv
|
36 |
+
DONE DEV
|
37 |
+
DONE Test
|
38 |
dev acc :: 0.873
|
unsup_simcse.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
import argparse
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from tqdm import tqdm
|
8 |
+
from types import SimpleNamespace
|
9 |
+
from torch.utils.data import Dataset, DataLoader
|
10 |
+
from sklearn.metrics import f1_score, accuracy_score
|
11 |
+
|
12 |
+
from bert import BertModel
|
13 |
+
from optimizer import AdamW
|
14 |
+
from classifier import seed_everything, tokenizer
|
15 |
+
from classifier import SentimentDataset, BertSentimentClassifier
|
16 |
+
|
17 |
+
|
18 |
+
TQDM_DISABLE = False
|
19 |
+
|
20 |
+
|
21 |
+
class TwitterDataset(Dataset):
|
22 |
+
def __init__(self, dataset, args):
|
23 |
+
self.dataset = dataset
|
24 |
+
self.p = args
|
25 |
+
|
26 |
+
def __len__(self):
|
27 |
+
return len(self.dataset)
|
28 |
+
|
29 |
+
def __getitem__(self, idx):
|
30 |
+
return self.dataset[idx]
|
31 |
+
|
32 |
+
def pad_data(self, sents):
|
33 |
+
encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
|
34 |
+
token_ids = torch.LongTensor(encoding['input_ids'])
|
35 |
+
attension_mask = torch.LongTensor(encoding['attention_mask'])
|
36 |
+
|
37 |
+
return token_ids, attension_mask
|
38 |
+
|
39 |
+
def collate_fn(self, sents):
|
40 |
+
token_ids, attention_mask = self.pad_data(sents)
|
41 |
+
|
42 |
+
batched_data = {
|
43 |
+
'token_ids': token_ids,
|
44 |
+
'attention_mask': attention_mask,
|
45 |
+
}
|
46 |
+
|
47 |
+
return batched_data
|
48 |
+
|
49 |
+
|
50 |
+
def load_data(filename, flag='train'):
|
51 |
+
'''
|
52 |
+
- for Twitter dataset: list of sentences
|
53 |
+
- for SST/CFIMDB dataset: list of (sent, [label], sent_id)
|
54 |
+
'''
|
55 |
+
num_labels = set()
|
56 |
+
data = []
|
57 |
+
with open(filename, 'r') as fp:
|
58 |
+
for record in csv.DictReader(fp, delimiter = ',', ):
|
59 |
+
if flag == 'twitter':
|
60 |
+
sent = record['clean_text'].lower().strip()
|
61 |
+
data.append(sent)
|
62 |
+
elif flag == 'test':
|
63 |
+
sent = record['sentence'].lower().strip()
|
64 |
+
sent_id = record['id'].lower().strip()
|
65 |
+
data.append((sent,sent_id))
|
66 |
+
else:
|
67 |
+
sent = record['sentence'].lower().strip()
|
68 |
+
sent_id = record['id'].lower().strip()
|
69 |
+
label = int(record['sentiment'].strip())
|
70 |
+
num_labels.add(label)
|
71 |
+
data.append((sent, label, sent_id))
|
72 |
+
print(f"load {len(data)} data from {filename}")
|
73 |
+
|
74 |
+
if flag == 'train':
|
75 |
+
return data, len(num_labels)
|
76 |
+
else:
|
77 |
+
return data
|
78 |
+
|
79 |
+
|
80 |
+
def save_model(model, optimizer, args, config, filepath):
|
81 |
+
save_info = {
|
82 |
+
'model': model.state_dict(),
|
83 |
+
'optim': optimizer.state_dict(),
|
84 |
+
'args': args,
|
85 |
+
'model_config': config,
|
86 |
+
'system_rng': random.getstate(),
|
87 |
+
'numpy_rng': np.random.get_state(),
|
88 |
+
'torch_rng': torch.random.get_rng_state(),
|
89 |
+
}
|
90 |
+
|
91 |
+
torch.save(save_info, filepath)
|
92 |
+
print(f"save the model to {filepath}")
|
93 |
+
|
94 |
+
|
95 |
+
def train(args):
|
96 |
+
'''
|
97 |
+
Training Pipeline
|
98 |
+
-----------------
|
99 |
+
1. Load the Twitter Sentiment and SST Dataset.
|
100 |
+
2. Determine batch_size (64) and number of batches (?).
|
101 |
+
3. Initialize SentimentClassifier (including bert).
|
102 |
+
4. Looping through 10 epoches.
|
103 |
+
5. Finetune minBERT with SimCSE loss function.
|
104 |
+
6. Finetune Classifier with cross-entropy function.
|
105 |
+
7. Backpropagation using Adam Optimizer for both.
|
106 |
+
8. Evaluating the model on dev dataset.
|
107 |
+
9. If dev_acc > best_dev_acc: save_model(...)
|
108 |
+
'''
|
109 |
+
|
110 |
+
twitter_data = load_data(args.train_bert, 'twitter')
|
111 |
+
train_data, num_labels = load_data(args.train, 'train')
|
112 |
+
dev_data = load_data(args.dev, 'valid')
|
113 |
+
|
114 |
+
twitter_dataset = TwitterDataset(twitter_data, args)
|
115 |
+
train_dataset = SentimentDataset(train_data, args)
|
116 |
+
dev_dataset = SentimentDataset(dev_data, args)
|
117 |
+
|
118 |
+
twitter_dataloader = DataLoader(twitter_dataset, shuffle=True, batch_size=args.batch_size_cse,
|
119 |
+
num_workers=args.num_cpu_cores, collate_fn=twitter_dataset.collate_fn)
|
120 |
+
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_classifier,
|
121 |
+
num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
|
122 |
+
dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size_classifier,
|
123 |
+
num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)
|
124 |
+
|
125 |
+
config = SimpleNamespace(
|
126 |
+
hidden_dropout_prob=args.hidden_dropout_prob,
|
127 |
+
num_labels=num_labels,
|
128 |
+
hidden_size=768,
|
129 |
+
data_dir='.',
|
130 |
+
fine_tune_mode='full-model'
|
131 |
+
)
|
132 |
+
|
133 |
+
device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
|
134 |
+
model = BertSentimentClassifier(config)
|
135 |
+
model = model.to(device)
|
136 |
+
|
137 |
+
optimizer_cse = AdamW(model.bert.parameters(), lr=args.lr_cse)
|
138 |
+
optimizer_classifier = AdamW(model.parameters(), lr=args.lr_classifier)
|
139 |
+
best_dev_acc = 0
|
140 |
+
|
141 |
+
for epoch in range(args.epochs):
|
142 |
+
model.bert.train()
|
143 |
+
train_loss = num_batches = 0
|
144 |
+
for batch in tqdm(twitter_dataloader, f'train-twitter-{epoch}', leave=False, disable=TQDM_DISABLE):
|
145 |
+
b_ids, b_mask = batch['token_ids'], batch['attention_mask']
|
146 |
+
b_ids = b_ids.to(device)
|
147 |
+
b_mask = b_mask.to(device)
|
148 |
+
|
149 |
+
optimizer_cse.zero_grad()
|
150 |
+
logits = model.bert.embed(b_ids)
|
151 |
+
logits = model.bert.encode(logits, b_mask)
|
152 |
+
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
def get_args():
|
157 |
+
parser = argparse.ArgumentParser()
|
158 |
+
parser.add_argument("--seed", type=int, default=11711)
|
159 |
+
parser.add_argument("--num-cpu-cores", type=int, default=4)
|
160 |
+
parser.add_argument("--epochs", type=int, default=10)
|
161 |
+
parser.add_argument("--use_gpu", action='store_true')
|
162 |
+
parser.add_argument("--batch_size_cse", help="'unsup': 64, 'sup': 512", type=int)
|
163 |
+
parser.add_argument("--batch_size_classifier", help="'sst': 64, 'cfimdb': 8", type=int)
|
164 |
+
parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
|
165 |
+
parser.add_argument("--lr_cse", default=2e-5)
|
166 |
+
parser.add_argument("--lr_classifier", default=1e-5)
|
167 |
+
|
168 |
+
args = parser.parse_args()
|
169 |
+
return args
|
170 |
+
|
171 |
+
|
172 |
+
if __name__ == "__main__":
|
173 |
+
args = get_args()
|
174 |
+
seed_everything(args.seed)
|
175 |
+
torch.set_num_threads(args.num_cpu_cores)
|
176 |
+
|
177 |
+
print('Finetuning minBERT with Unsupervised SimCSE...')
|
178 |
+
config = SimpleNamespace(
|
179 |
+
filepath='contrastive-nli.pt',
|
180 |
+
lr=args.lr,
|
181 |
+
num_cpu_cores=args.num_cpu_cores,
|
182 |
+
use_gpu=args.use_gpu,
|
183 |
+
epochs=args.epochs,
|
184 |
+
batch_size_cse=args.batch_size_cse,
|
185 |
+
batch_size_classifier=args.batch_size_classifier,
|
186 |
+
train_bert='data/twitter-unsup.csv',
|
187 |
+
train='data/ids-sst-train.csv',
|
188 |
+
dev='data/ids-sst-dev.csv',
|
189 |
+
test='data/ids-sst-test-student.csv',
|
190 |
+
dev_out = 'predictions/' + args.fine_tune_mode + '-sst-dev-out.csv',
|
191 |
+
test_out = 'predictions/' + args.fine_tune_mode + '-sst-test-out.csv'
|
192 |
+
)
|
193 |
+
|
194 |
+
train(config)
|
195 |
+
|
196 |
+
# model = BertModel.from_pretrained('bert-base-uncased')
|
197 |
+
|
198 |
+
# model.eval()
|
199 |
+
|
200 |
+
# s = set()
|
201 |
+
# for param in model.parameters():
|
202 |
+
# s.add(param.requires_grad)
|
203 |
+
|
204 |
+
# print(s)
|