GlowCheese committed

Commit 7587354 · 1 Parent(s): 51eaae4

contrastive commit 1
base_bert.py CHANGED
@@ -5,244 +5,244 @@ from utils import *
5
 
6
 
7
  class BertPreTrainedModel(nn.Module):
8
- config_class = BertConfig
9
- base_model_prefix = "bert"
10
- _keys_to_ignore_on_load_missing = [r"position_ids"]
11
- _keys_to_ignore_on_load_unexpected = None
12
-
13
- def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
14
- super().__init__()
15
- self.config = config
16
- self.name_or_path = config.name_or_path
17
-
18
- def init_weights(self):
19
- # Initialize weights
20
- self.apply(self._init_weights)
21
-
22
- def _init_weights(self, module):
23
- """ Initialize the weights """
24
- if isinstance(module, (nn.Linear, nn.Embedding)):
25
- # Slightly different from the TF version which uses truncated_normal for initialization
26
- # cf https://github.com/pytorch/pytorch/pull/5617
27
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
28
- elif isinstance(module, nn.LayerNorm):
29
- module.bias.data.zero_()
30
- module.weight.data.fill_(1.0)
31
- if isinstance(module, nn.Linear) and module.bias is not None:
32
- module.bias.data.zero_()
33
-
34
- @property
35
- def dtype(self) -> dtype:
36
- return get_parameter_dtype(self)
37
-
38
- @classmethod
39
- def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
40
- config = kwargs.pop("config", None)
41
- state_dict = kwargs.pop("state_dict", None)
42
- cache_dir = kwargs.pop("cache_dir", None)
43
- force_download = kwargs.pop("force_download", False)
44
- resume_download = kwargs.pop("resume_download", False)
45
- proxies = kwargs.pop("proxies", None)
46
- output_loading_info = kwargs.pop("output_loading_info", False)
47
- local_files_only = kwargs.pop("local_files_only", False)
48
- use_auth_token = kwargs.pop("use_auth_token", None)
49
- revision = kwargs.pop("revision", None)
50
- mirror = kwargs.pop("mirror", None)
51
-
52
- # Load config if we don't provide a configuration
53
- if not isinstance(config, PretrainedConfig):
54
- config_path = config if config is not None else pretrained_model_name_or_path
55
- config, model_kwargs = cls.config_class.from_pretrained(
56
- config_path,
57
- *model_args,
58
- cache_dir=cache_dir,
59
- return_unused_kwargs=True,
60
- force_download=force_download,
61
- resume_download=resume_download,
62
- proxies=proxies,
63
- local_files_only=local_files_only,
64
- use_auth_token=use_auth_token,
65
- revision=revision,
66
- **kwargs,
67
- )
68
- else:
69
- model_kwargs = kwargs
70
-
71
- # Load model
72
- if pretrained_model_name_or_path is not None:
73
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
74
- if os.path.isdir(pretrained_model_name_or_path):
75
- # Load from a PyTorch checkpoint
76
- archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
77
- elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
78
- archive_file = pretrained_model_name_or_path
79
- else:
80
- archive_file = hf_bucket_url(
81
- pretrained_model_name_or_path,
82
- filename=WEIGHTS_NAME,
83
- revision=revision,
84
- mirror=mirror,
85
- )
86
- try:
87
- # Load from URL or cache if already cached
88
- resolved_archive_file = cached_path(
89
- archive_file,
90
- cache_dir=cache_dir,
91
- force_download=force_download,
92
- proxies=proxies,
93
- resume_download=resume_download,
94
- local_files_only=local_files_only,
95
- use_auth_token=use_auth_token,
96
- )
97
- except EnvironmentError as err:
98
- #logger.error(err)
99
- msg = (
100
- f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
101
- f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
102
- f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}.\n\n"
103
- )
104
- raise EnvironmentError(msg)
105
- else:
106
- resolved_archive_file = None
107
-
108
- config.name_or_path = pretrained_model_name_or_path
109
-
110
- # Instantiate model.
111
- model = cls(config, *model_args, **model_kwargs)
112
-
113
- if state_dict is None:
114
- try:
115
- state_dict = torch.load(resolved_archive_file, map_location="cpu", weights_only=True)
116
- except Exception:
117
- raise OSError(
118
- f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
119
- f"at '{resolved_archive_file}'"
120
- )
121
-
122
- missing_keys = []
123
- unexpected_keys = []
124
- error_msgs = []
125
-
126
- # Convert old format to new format if needed from a PyTorch state_dict
127
- old_keys = []
128
- new_keys = []
129
- m = {'embeddings.word_embeddings': 'word_embedding',
130
- 'embeddings.position_embeddings': 'pos_embedding',
131
- 'embeddings.token_type_embeddings': 'tk_type_embedding',
132
- 'embeddings.LayerNorm': 'embed_layer_norm',
133
- 'embeddings.dropout': 'embed_dropout',
134
- 'encoder.layer': 'bert_layers',
135
- 'pooler.dense': 'pooler_dense',
136
- 'pooler.activation': 'pooler_af',
137
- 'attention.self': "self_attention",
138
- 'attention.output.dense': 'attention_dense',
139
- 'attention.output.LayerNorm': 'attention_layer_norm',
140
- 'attention.output.dropout': 'attention_dropout',
141
- 'intermediate.dense': 'interm_dense',
142
- 'intermediate.intermediate_act_fn': 'interm_af',
143
- 'output.dense': 'out_dense',
144
- 'output.LayerNorm': 'out_layer_norm',
145
- 'output.dropout': 'out_dropout'}
146
-
147
- for key in state_dict.keys():
148
- new_key = None
149
- if "gamma" in key:
150
- new_key = key.replace("gamma", "weight")
151
- if "beta" in key:
152
- new_key = key.replace("beta", "bias")
153
- for x, y in m.items():
154
- if new_key is not None:
155
- _key = new_key
156
  else:
157
- _key = key
158
- if x in key:
159
- new_key = _key.replace(x, y)
160
- if new_key:
161
- old_keys.append(key)
162
- new_keys.append(new_key)
163
-
164
- for old_key, new_key in zip(old_keys, new_keys):
165
- # print(old_key, new_key)
166
- state_dict[new_key] = state_dict.pop(old_key)
167
-
168
- # copy state_dict so _load_from_state_dict can modify it
169
- metadata = getattr(state_dict, "_metadata", None)
170
- state_dict = state_dict.copy()
171
- if metadata is not None:
172
- state_dict._metadata = metadata
173
-
174
- your_bert_params = [f"bert.{x[0]}" for x in model.named_parameters()]
175
- for k in state_dict:
176
- if k not in your_bert_params and not k.startswith("cls."):
177
- possible_rename = [x for x in k.split(".")[1:-1] if x in m.values()]
178
- raise ValueError(f"{k} cannot be reloaded into your model; one/some of {possible_rename} we provided have been renamed")
179
-
180
- # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
181
- # so we need to apply the function recursively.
182
- def load(module: nn.Module, prefix=""):
183
- local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
184
- module._load_from_state_dict(
185
- state_dict,
186
- prefix,
187
- local_metadata,
188
- True,
189
- missing_keys,
190
- unexpected_keys,
191
- error_msgs,
192
- )
193
- for name, child in module._modules.items():
194
- if child is not None:
195
- load(child, prefix + name + ".")
196
-
197
- # Make sure we are able to load base models as well as derived models (with heads)
198
- start_prefix = ""
199
- model_to_load = model
200
- has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
201
- if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
202
- start_prefix = cls.base_model_prefix + "."
203
- if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
204
- model_to_load = getattr(model, cls.base_model_prefix)
205
- load(model_to_load, prefix=start_prefix)
206
-
207
- if model.__class__.__name__ != model_to_load.__class__.__name__:
208
- base_model_state_dict = model_to_load.state_dict().keys()
209
- head_model_state_dict_without_base_prefix = [
210
- key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
211
- ]
212
- missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
213
-
214
- # Some models may have keys that are not in the state by design, removing them before needlessly warning
215
- # the user.
216
- if cls._keys_to_ignore_on_load_missing is not None:
217
- for pat in cls._keys_to_ignore_on_load_missing:
218
- missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
219
-
220
- if cls._keys_to_ignore_on_load_unexpected is not None:
221
- for pat in cls._keys_to_ignore_on_load_unexpected:
222
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
223
-
224
- if len(error_msgs) > 0:
225
- raise RuntimeError(
226
- "Error(s) in loading state_dict for {}:\n\t{}".format(
227
- model.__class__.__name__, "\n\t".join(error_msgs)
228
- )
229
- )
230
-
231
- # Set model in evaluation mode to deactivate DropOut modules by default
232
- model.eval()
233
-
234
- if output_loading_info:
235
- loading_info = {
236
- "missing_keys": missing_keys,
237
- "unexpected_keys": unexpected_keys,
238
- "error_msgs": error_msgs,
239
- }
240
- return model, loading_info
241
-
242
- if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
243
- import torch_xla.core.xla_model as xm
244
-
245
- model = xm.send_cpu_data_to_device(model, xm.xla_device())
246
- model.to(xm.xla_device())
247
-
248
- return model
5
 
6
 
7
  class BertPreTrainedModel(nn.Module):
8
+ config_class = BertConfig
9
+ base_model_prefix = "bert"
10
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
11
+ _keys_to_ignore_on_load_unexpected = None
12
+
13
+ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
14
+ super().__init__()
15
+ self.config = config
16
+ self.name_or_path = config.name_or_path
17
+
18
+ def init_weights(self):
19
+ # Initialize weights
20
+ self.apply(self._init_weights)
21
+
22
+ def _init_weights(self, module):
23
+ """ Initialize the weights """
24
+ if isinstance(module, (nn.Linear, nn.Embedding)):
25
+ # Slightly different from the TF version which uses truncated_normal for initialization
26
+ # cf https://github.com/pytorch/pytorch/pull/5617
27
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
28
+ elif isinstance(module, nn.LayerNorm):
29
+ module.bias.data.zero_()
30
+ module.weight.data.fill_(1.0)
31
+ if isinstance(module, nn.Linear) and module.bias is not None:
32
+ module.bias.data.zero_()
33
+
34
+ @property
35
+ def dtype(self) -> dtype:
36
+ return get_parameter_dtype(self)
37
+
38
+ @classmethod
39
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
40
+ config = kwargs.pop("config", None)
41
+ state_dict = kwargs.pop("state_dict", None)
42
+ cache_dir = kwargs.pop("cache_dir", None)
43
+ force_download = kwargs.pop("force_download", False)
44
+ resume_download = kwargs.pop("resume_download", False)
45
+ proxies = kwargs.pop("proxies", None)
46
+ output_loading_info = kwargs.pop("output_loading_info", False)
47
+ local_files_only = kwargs.pop("local_files_only", False)
48
+ use_auth_token = kwargs.pop("use_auth_token", None)
49
+ revision = kwargs.pop("revision", None)
50
+ mirror = kwargs.pop("mirror", None)
51
+
52
+ # Load config if we don't provide a configuration
53
+ if not isinstance(config, PretrainedConfig):
54
+ config_path = config if config is not None else pretrained_model_name_or_path
55
+ config, model_kwargs = cls.config_class.from_pretrained(
56
+ config_path,
57
+ *model_args,
58
+ cache_dir=cache_dir,
59
+ return_unused_kwargs=True,
60
+ force_download=force_download,
61
+ resume_download=resume_download,
62
+ proxies=proxies,
63
+ local_files_only=local_files_only,
64
+ use_auth_token=use_auth_token,
65
+ revision=revision,
66
+ **kwargs,
67
+ )
68
  else:
69
+ model_kwargs = kwargs
70
+
71
+ # Load model
72
+ if pretrained_model_name_or_path is not None:
73
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
74
+ if os.path.isdir(pretrained_model_name_or_path):
75
+ # Load from a PyTorch checkpoint
76
+ archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
77
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
78
+ archive_file = pretrained_model_name_or_path
79
+ else:
80
+ archive_file = hf_bucket_url(
81
+ pretrained_model_name_or_path,
82
+ filename=WEIGHTS_NAME,
83
+ revision=revision,
84
+ mirror=mirror,
85
+ )
86
+ try:
87
+ # Load from URL or cache if already cached
88
+ resolved_archive_file = cached_path(
89
+ archive_file,
90
+ cache_dir=cache_dir,
91
+ force_download=force_download,
92
+ proxies=proxies,
93
+ resume_download=resume_download,
94
+ local_files_only=local_files_only,
95
+ use_auth_token=use_auth_token,
96
+ )
97
+ except EnvironmentError as err:
98
+ #logger.error(err)
99
+ msg = (
100
+ f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
101
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
102
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}.\n\n"
103
+ )
104
+ raise EnvironmentError(msg)
105
+ else:
106
+ resolved_archive_file = None
107
+
108
+ config.name_or_path = pretrained_model_name_or_path
109
+
110
+ # Instantiate model.
111
+ model = cls(config, *model_args, **model_kwargs)
112
+
113
+ if state_dict is None:
114
+ try:
115
+ state_dict = torch.load(resolved_archive_file, map_location="cpu", weights_only=True)
116
+ except Exception:
117
+ raise OSError(
118
+ f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
119
+ f"at '{resolved_archive_file}'"
120
+ )
121
+
122
+ missing_keys = []
123
+ unexpected_keys = []
124
+ error_msgs = []
125
+
126
+ # Convert old format to new format if needed from a PyTorch state_dict
127
+ old_keys = []
128
+ new_keys = []
129
+ m = {'embeddings.word_embeddings': 'word_embedding',
130
+ 'embeddings.position_embeddings': 'pos_embedding',
131
+ 'embeddings.token_type_embeddings': 'tk_type_embedding',
132
+ 'embeddings.LayerNorm': 'embed_layer_norm',
133
+ 'embeddings.dropout': 'embed_dropout',
134
+ 'encoder.layer': 'bert_layers',
135
+ 'pooler.dense': 'pooler_dense',
136
+ 'pooler.activation': 'pooler_af',
137
+ 'attention.self': "self_attention",
138
+ 'attention.output.dense': 'attention_dense',
139
+ 'attention.output.LayerNorm': 'attention_layer_norm',
140
+ 'attention.output.dropout': 'attention_dropout',
141
+ 'intermediate.dense': 'interm_dense',
142
+ 'intermediate.intermediate_act_fn': 'interm_af',
143
+ 'output.dense': 'out_dense',
144
+ 'output.LayerNorm': 'out_layer_norm',
145
+ 'output.dropout': 'out_dropout'}
146
+
147
+ for key in state_dict.keys():
148
+ new_key = None
149
+ if "gamma" in key:
150
+ new_key = key.replace("gamma", "weight")
151
+ if "beta" in key:
152
+ new_key = key.replace("beta", "bias")
153
+ for x, y in m.items():
154
+ if new_key is not None:
155
+ _key = new_key
156
+ else:
157
+ _key = key
158
+ if x in key:
159
+ new_key = _key.replace(x, y)
160
+ if new_key:
161
+ old_keys.append(key)
162
+ new_keys.append(new_key)
163
+
164
+ for old_key, new_key in zip(old_keys, new_keys):
165
+ # print(old_key, new_key)
166
+ state_dict[new_key] = state_dict.pop(old_key)
167
+
168
+ # copy state_dict so _load_from_state_dict can modify it
169
+ metadata = getattr(state_dict, "_metadata", None)
170
+ state_dict = state_dict.copy()
171
+ if metadata is not None:
172
+ state_dict._metadata = metadata
173
+
174
+ your_bert_params = [f"bert.{x[0]}" for x in model.named_parameters()]
175
+ for k in state_dict:
176
+ if k not in your_bert_params and not k.startswith("cls."):
177
+ possible_rename = [x for x in k.split(".")[1:-1] if x in m.values()]
178
+ raise ValueError(f"{k} cannot be reloaded into your model; one/some of {possible_rename} we provided have been renamed")
179
+
180
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
181
+ # so we need to apply the function recursively.
182
+ def load(module: nn.Module, prefix=""):
183
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
184
+ module._load_from_state_dict(
185
+ state_dict,
186
+ prefix,
187
+ local_metadata,
188
+ True,
189
+ missing_keys,
190
+ unexpected_keys,
191
+ error_msgs,
192
+ )
193
+ for name, child in module._modules.items():
194
+ if child is not None:
195
+ load(child, prefix + name + ".")
196
+
197
+ # Make sure we are able to load base models as well as derived models (with heads)
198
+ start_prefix = ""
199
+ model_to_load = model
200
+ has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
201
+ if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
202
+ start_prefix = cls.base_model_prefix + "."
203
+ if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
204
+ model_to_load = getattr(model, cls.base_model_prefix)
205
+ load(model_to_load, prefix=start_prefix)
206
+
207
+ if model.__class__.__name__ != model_to_load.__class__.__name__:
208
+ base_model_state_dict = model_to_load.state_dict().keys()
209
+ head_model_state_dict_without_base_prefix = [
210
+ key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
211
+ ]
212
+ missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
213
+
214
+ # Some models may have keys that are not in the state by design, removing them before needlessly warning
215
+ # the user.
216
+ if cls._keys_to_ignore_on_load_missing is not None:
217
+ for pat in cls._keys_to_ignore_on_load_missing:
218
+ missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
219
+
220
+ if cls._keys_to_ignore_on_load_unexpected is not None:
221
+ for pat in cls._keys_to_ignore_on_load_unexpected:
222
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
223
+
224
+ if len(error_msgs) > 0:
225
+ raise RuntimeError(
226
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
227
+ model.__class__.__name__, "\n\t".join(error_msgs)
228
+ )
229
+ )
230
+
231
+ # Set model in evaluation mode to deactivate DropOut modules by default
232
+ model.eval()
233
+
234
+ if output_loading_info:
235
+ loading_info = {
236
+ "missing_keys": missing_keys,
237
+ "unexpected_keys": unexpected_keys,
238
+ "error_msgs": error_msgs,
239
+ }
240
+ return model, loading_info
241
+
242
+ if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
243
+ import torch_xla.core.xla_model as xm
244
+
245
+ model = xm.send_cpu_data_to_device(model, xm.xla_device())
246
+ model.to(xm.xla_device())
247
+
248
+ return model
classifier.py CHANGED
@@ -3,6 +3,7 @@ from types import SimpleNamespace
3
  import csv
4
 
5
  import torch
 
6
  import torch.nn.functional as F
7
  from torch.utils.data import Dataset, DataLoader
8
  from sklearn.metrics import f1_score, accuracy_score
@@ -10,7 +11,6 @@ from sklearn.metrics import f1_score, accuracy_score
10
  from tokenizer import BertTokenizer
11
  from bert import BertModel
12
  from optimizer import AdamW
13
- from tqdm import tqdm
14
 
15
 
16
  TQDM_DISABLE=False
@@ -34,10 +34,10 @@ class BertSentimentClassifier(torch.nn.Module):
34
  In the SST dataset, there are 5 sentiment categories (from 0 - "negative" to 4 - "positive").
35
  Thus, your forward() should return one logit for each of the 5 classes.
36
  '''
37
- def __init__(self, config):
38
  super(BertSentimentClassifier, self).__init__()
39
  self.num_labels = config.num_labels
40
- self.bert: BertModel = BertModel.from_pretrained('bert-base-uncased')
41
 
42
  # Pretrain mode does not require updating BERT parameters.
43
  assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
@@ -59,26 +59,21 @@ class BertSentimentClassifier(torch.nn.Module):
59
  # the training loop currently uses F.cross_entropy as the loss function.
60
 
61
  # Get the embedding for each input token.
62
- embedding_output = self.bert.embed(input_ids=input_ids)
63
-
64
- # Feed to a transformer (BERT layers).
65
- sequence_output = self.bert.encode(embedding_output, attention_mask=attention_mask)
66
-
67
- # The final BERT contextualized embedding is the hidden state of [CLS] token (the first token).
68
- cls_token_output = sequence_output[:, 0, :] # The first token is [CLS]
69
 
70
  # Pass the [CLS] token representation through the classifier.
71
- logits = self.classifier(self.dropout(cls_token_output))
72
 
73
  return logits
74
 
75
 
 
76
 
77
  class SentimentDataset(Dataset):
78
  def __init__(self, dataset, args):
79
  self.dataset = dataset
80
  self.p = args
81
- self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
82
 
83
  def __len__(self):
84
  return len(self.dataset)
@@ -91,7 +86,7 @@ class SentimentDataset(Dataset):
91
  labels = [x[1] for x in data]
92
  sent_ids = [x[2] for x in data]
93
 
94
- encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
95
  token_ids = torch.LongTensor(encoding['input_ids'])
96
  attention_mask = torch.LongTensor(encoding['attention_mask'])
97
  labels = torch.LongTensor(labels)
@@ -99,15 +94,15 @@ class SentimentDataset(Dataset):
99
  return token_ids, attention_mask, labels, sents, sent_ids
100
 
101
  def collate_fn(self, all_data):
102
- token_ids, attention_mask, labels, sents, sent_ids= self.pad_data(all_data)
103
 
104
  batched_data = {
105
- 'token_ids': token_ids,
106
- 'attention_mask': attention_mask,
107
- 'labels': labels,
108
- 'sents': sents,
109
- 'sent_ids': sent_ids
110
- }
111
 
112
  return batched_data
113
 
@@ -116,7 +111,6 @@ class SentimentTestDataset(Dataset):
116
  def __init__(self, dataset, args):
117
  self.dataset = dataset
118
  self.p = args
119
- self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
120
 
121
  def __len__(self):
122
  return len(self.dataset)
@@ -128,7 +122,7 @@ class SentimentTestDataset(Dataset):
128
  sents = [x[0] for x in data]
129
  sent_ids = [x[1] for x in data]
130
 
131
- encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
132
  token_ids = torch.LongTensor(encoding['input_ids'])
133
  attention_mask = torch.LongTensor(encoding['attention_mask'])
134
 
@@ -138,34 +132,31 @@ class SentimentTestDataset(Dataset):
138
  token_ids, attention_mask, sents, sent_ids= self.pad_data(all_data)
139
 
140
  batched_data = {
141
- 'token_ids': token_ids,
142
- 'attention_mask': attention_mask,
143
- 'sents': sents,
144
- 'sent_ids': sent_ids
145
- }
146
 
147
  return batched_data
148
 
149
 
150
  # Load the data: a list of (sentence, label).
151
  def load_data(filename, flag='train'):
152
- num_labels = {}
153
  data = []
154
- if flag == 'test':
155
- with open(filename, 'r') as fp:
156
- for record in csv.DictReader(fp,delimiter = '\t'):
157
  sent = record['sentence'].lower().strip()
158
  sent_id = record['id'].lower().strip()
159
  data.append((sent,sent_id))
160
- else:
161
- with open(filename, 'r') as fp:
162
- for record in csv.DictReader(fp,delimiter = '\t'):
163
  sent = record['sentence'].lower().strip()
164
  sent_id = record['id'].lower().strip()
165
  label = int(record['sentiment'].strip())
166
- if label not in num_labels:
167
- num_labels[label] = len(num_labels)
168
- data.append((sent, label,sent_id))
169
  print(f"load {len(data)} data from {filename}")
170
 
171
  if flag == 'train':
@@ -253,9 +244,9 @@ def train(args):
253
  dev_dataset = SentimentDataset(dev_data, args)
254
 
255
  train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
256
- collate_fn=train_dataset.collate_fn)
257
  dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
258
- collate_fn=dev_dataset.collate_fn)
259
 
260
  # Init model.
261
  config = {'hidden_dropout_prob': args.hidden_dropout_prob,
@@ -311,7 +302,7 @@ def train(args):
311
  def test(args):
312
  with torch.no_grad():
313
  device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
314
- saved = torch.load(args.filepath)
315
  config = saved['model_config']
316
  model = BertSentimentClassifier(config)
317
  model.load_state_dict(saved['model'])
@@ -320,38 +311,44 @@ def test(args):
320
 
321
  dev_data = load_data(args.dev, 'valid')
322
  dev_dataset = SentimentDataset(dev_data, args)
323
- dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=dev_dataset.collate_fn)
 
324
 
325
- test_data = load_data(args.test, 'test')
326
- test_dataset = SentimentTestDataset(test_data, args)
327
- test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=test_dataset.collate_fn)
328
-
329
  dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
330
  print('DONE DEV')
331
- test_pred, test_sents, test_sent_ids = model_test_eval(test_dataloader, model, device)
332
- print('DONE Test')
333
- with open(args.dev_out, "w+") as f:
334
- print(f"dev acc :: {dev_acc :.3f}")
335
- f.write(f"id \t Predicted_Sentiment \n")
336
- for p, s in zip(dev_sent_ids,dev_pred ):
337
- f.write(f"{p} , {s} \n")
338
-
339
- with open(args.test_out, "w+") as f:
340
- f.write(f"id \t Predicted_Sentiment \n")
341
- for p, s in zip(test_sent_ids,test_pred ):
342
- f.write(f"{p} , {s} \n")
343
 
344
 
345
  def get_args():
346
  parser = argparse.ArgumentParser()
347
  parser.add_argument("--seed", type=int, default=11711)
 
348
  parser.add_argument("--epochs", type=int, default=10)
349
  parser.add_argument("--fine-tune-mode", type=str,
350
  help='last-linear-layer: the BERT parameters are frozen and the task specific head parameters are updated; full-model: BERT parameters are updated as well',
351
  choices=('last-linear-layer', 'full-model'), default="last-linear-layer")
352
  parser.add_argument("--use_gpu", action='store_true')
353
 
354
- parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
 
355
  parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
356
  parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
357
  default=1e-3)
@@ -360,17 +357,21 @@ def get_args():
360
  return args
361
 
362
 
363
- if __name__ == "__main__":
364
  args = get_args()
365
  seed_everything(args.seed)
366
 
367
  print('Training Sentiment Classifier on SST...')
368
  config = SimpleNamespace(
369
  filepath='sst-classifier.pt',
370
  lr=args.lr,
 
371
  use_gpu=args.use_gpu,
372
  epochs=args.epochs,
373
- batch_size=args.batch_size,
374
  hidden_dropout_prob=args.hidden_dropout_prob,
375
  train='data/ids-sst-train.csv',
376
  dev='data/ids-sst-dev.csv',
@@ -389,9 +390,10 @@ if __name__ == "__main__":
389
  config = SimpleNamespace(
390
  filepath='cfimdb-classifier.pt',
391
  lr=args.lr,
 
392
  use_gpu=args.use_gpu,
393
  epochs=args.epochs,
394
- batch_size=8,
395
  hidden_dropout_prob=args.hidden_dropout_prob,
396
  train='data/ids-cfimdb-train.csv',
397
  dev='data/ids-cfimdb-dev.csv',
@@ -405,3 +407,7 @@ if __name__ == "__main__":
405
 
406
  print('Evaluating on cfimdb...')
407
  test(config)
3
  import csv
4
 
5
  import torch
6
+ from tqdm import tqdm
7
  import torch.nn.functional as F
8
  from torch.utils.data import Dataset, DataLoader
9
  from sklearn.metrics import f1_score, accuracy_score
 
11
  from tokenizer import BertTokenizer
12
  from bert import BertModel
13
  from optimizer import AdamW
 
14
 
15
 
16
  TQDM_DISABLE=False
 
34
  In the SST dataset, there are 5 sentiment categories (from 0 - "negative" to 4 - "positive").
35
  Thus, your forward() should return one logit for each of the 5 classes.
36
  '''
37
+ def __init__(self, config, bert_model = None):
38
  super(BertSentimentClassifier, self).__init__()
39
  self.num_labels = config.num_labels
40
+ self.bert: BertModel = bert_model or BertModel.from_pretrained('bert-base-uncased')
41
 
42
  # Pretrain mode does not require updating BERT parameters.
43
  assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
 
59
  # the training loop currently uses F.cross_entropy as the loss function.
60
 
61
  # Get the embedding for each input token.
62
+ outputs = self.bert(input_ids, attention_mask)
63
+ pooler_output = outputs['pooler_output']
64
 
65
  # Pass the [CLS] token representation through the classifier.
66
+ logits = self.classifier(self.dropout(pooler_output))
67
 
68
  return logits
69
 
70
 
71
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
72
 
73
  class SentimentDataset(Dataset):
74
  def __init__(self, dataset, args):
75
  self.dataset = dataset
76
  self.p = args
 
77
 
78
  def __len__(self):
79
  return len(self.dataset)
 
86
  labels = [x[1] for x in data]
87
  sent_ids = [x[2] for x in data]
88
 
89
+ encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
90
  token_ids = torch.LongTensor(encoding['input_ids'])
91
  attention_mask = torch.LongTensor(encoding['attention_mask'])
92
  labels = torch.LongTensor(labels)
 
94
  return token_ids, attention_mask, labels, sents, sent_ids
95
 
96
  def collate_fn(self, all_data):
97
+ token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)
98
 
99
  batched_data = {
100
+ 'token_ids': token_ids,
101
+ 'attention_mask': attention_mask,
102
+ 'labels': labels,
103
+ 'sents': sents,
104
+ 'sent_ids': sent_ids
105
+ }
106
 
107
  return batched_data
108
 
 
111
  def __init__(self, dataset, args):
112
  self.dataset = dataset
113
  self.p = args
 
114
 
115
  def __len__(self):
116
  return len(self.dataset)
 
122
  sents = [x[0] for x in data]
123
  sent_ids = [x[1] for x in data]
124
 
125
+ encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
126
  token_ids = torch.LongTensor(encoding['input_ids'])
127
  attention_mask = torch.LongTensor(encoding['attention_mask'])
128
 
 
132
  token_ids, attention_mask, sents, sent_ids= self.pad_data(all_data)
133
 
134
  batched_data = {
135
+ 'token_ids': token_ids,
136
+ 'attention_mask': attention_mask,
137
+ 'sents': sents,
138
+ 'sent_ids': sent_ids
139
+ }
140
 
141
  return batched_data
142
 
143
 
144
  # Load the data: a list of (sentence, label).
145
  def load_data(filename, flag='train'):
146
+ num_labels = set()
147
  data = []
148
+ with open(filename, 'r') as fp:
149
+ for record in csv.DictReader(fp, delimiter = '\t'):
150
+ if flag == 'test':
151
  sent = record['sentence'].lower().strip()
152
  sent_id = record['id'].lower().strip()
153
  data.append((sent,sent_id))
154
+ else:
155
  sent = record['sentence'].lower().strip()
156
  sent_id = record['id'].lower().strip()
157
  label = int(record['sentiment'].strip())
158
+ num_labels.add(label)
159
+ data.append((sent, label, sent_id))
 
160
  print(f"load {len(data)} data from {filename}")
161
 
162
  if flag == 'train':
 
244
  dev_dataset = SentimentDataset(dev_data, args)
245
 
246
  train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
247
+ num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
248
  dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
249
+ num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)
250
 
251
  # Init model.
252
  config = {'hidden_dropout_prob': args.hidden_dropout_prob,
 
302
  def test(args):
303
  with torch.no_grad():
304
  device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
305
+ saved = torch.load(args.filepath, weights_only=False)
306
  config = saved['model_config']
307
  model = BertSentimentClassifier(config)
308
  model.load_state_dict(saved['model'])
 
311
 
312
  dev_data = load_data(args.dev, 'valid')
313
  dev_dataset = SentimentDataset(dev_data, args)
314
+ dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
315
+ num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)
316
317
  dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
318
  print('DONE DEV')
319
+ print(f"dev acc :: {dev_acc :.3f}")
320
+
321
+ # ---- SKIP RUNNING ON TEST DATASET ---- #
322
+ # test_data = load_data(args.test, 'test')
323
+ # test_dataset = SentimentTestDataset(test_data, args)
324
+ # test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size,
325
+ # num_workers=args.num_cpu_cores, collate_fn=test_dataset.collate_fn)
326
+ # test_pred, test_sents, test_sent_ids = model_test_eval(test_dataloader, model, device)
327
+ # print('DONE TEST')
328
+
329
+ # ---- SKIP SAVING PREDICTIONS ----
330
+ # with open(args.dev_out, "w+") as f:
331
+ # f.write(f"id \t Predicted_Sentiment \n")
332
+ # for p, s in zip(dev_sent_ids,dev_pred):
333
+ # f.write(f"{p} , {s} \n")
334
+ # with open(args.test_out, "w+") as f:
335
+ # f.write(f"id \t Predicted_Sentiment \n")
336
+ # for p, s in zip(test_sent_ids,test_pred ):
337
+ # f.write(f"{p} , {s} \n")
338
 
339
 
340
  def get_args():
341
  parser = argparse.ArgumentParser()
342
  parser.add_argument("--seed", type=int, default=11711)
343
+ parser.add_argument("--num-cpu-cores", type=int, default=4)
344
  parser.add_argument("--epochs", type=int, default=10)
345
  parser.add_argument("--fine-tune-mode", type=str,
346
  help='last-linear-layer: the BERT parameters are frozen and the task specific head parameters are updated; full-model: BERT parameters are updated as well',
347
  choices=('last-linear-layer', 'full-model'), default="last-linear-layer")
348
  parser.add_argument("--use_gpu", action='store_true')
349
 
350
+ parser.add_argument("--batch_size_sst", help='64 can fit a 12GB GPU', type=int, default=8)
351
+ parser.add_argument("--batch_size_cfimdb", help='8 can fit a 12GB GPU', type=int, default=8)
352
  parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
353
  parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
354
  default=1e-3)
 
357
  return args
358
 
359
 
360
+ def main():
361
  args = get_args()
362
  seed_everything(args.seed)
363
+ torch.set_num_threads(args.num_cpu_cores)
364
+
365
+ print(torch.get_num_threads())
366
 
367
  print('Training Sentiment Classifier on SST...')
368
  config = SimpleNamespace(
369
  filepath='sst-classifier.pt',
370
  lr=args.lr,
371
+ num_cpu_cores=args.num_cpu_cores,
372
  use_gpu=args.use_gpu,
373
  epochs=args.epochs,
374
+ batch_size=args.batch_size_sst,
375
  hidden_dropout_prob=args.hidden_dropout_prob,
376
  train='data/ids-sst-train.csv',
377
  dev='data/ids-sst-dev.csv',
 
390
  config = SimpleNamespace(
391
  filepath='cfimdb-classifier.pt',
392
  lr=args.lr,
393
+ num_cpu_cores=args.num_cpu_cores,
394
  use_gpu=args.use_gpu,
395
  epochs=args.epochs,
396
+ batch_size=args.batch_size_cfimdb,
397
  hidden_dropout_prob=args.hidden_dropout_prob,
398
  train='data/ids-cfimdb-train.csv',
399
  dev='data/ids-cfimdb-dev.csv',
 
407
 
408
  print('Evaluating on cfimdb...')
409
  test(config)
410
+
411
+
412
+ if __name__ == "__main__":
413
+ main()
data/{sts-test-student.csv → nli-dev.parquet} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dee455745b72e9ca3ff74e7c056bd73e34bad5b8d5641045a2c1e7e131866f47
3
- size 256677
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c267496435885e724abc71e53669fae59db875bfa13389eab8f9b0b2dfb2b32e
3
+ size 782233
data/{sts-train.csv → nli-test.parquet} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:15d12efc2d656fffb1d61ac1f08ec4227f43925fd16f420c037cbd063699c21b
3
- size 928832
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01688df43ae4c019a86144a0d2351146b124688a55f285071cccd156225a5fdf
3
+ size 810423
data/{quora-test-student.csv → nli-train.parquet} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4fa130f532cdde70287081aa04af13a4b12e3aa862e9162763d15fb46385497a
3
- size 13487951
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9aeca80b1bda983ee316f854ebc37af8341877fb932dd6a2c6aba978ad112a5
3
+ size 38396324
data/{sts-dev.csv → stsb-dev.parquet} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ce3cad6f16062586ac7ba462c28b010a9be10c530fd5074165860d7b7ab4e93d
3
- size 132265
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c6e0e9881f1b398abe3e439a482f4686305c3784568c462f6bba58bdff03b0a
3
+ size 142187
data/{quora-dev.csv → stsb-test.parquet} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1e9dc46b273a711d82a065f55e1754a9b92c10ad7345ebe0b0ebba61397dda4a
3
- size 6896912
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8acbc291c50977d8655934952956016c3e049c2fe04f8a6c454c1bf6acc42ca1
3
+ size 108100
data/stsb-train.parquet ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eae324ff1eac2d0ba769851736eb7232eda64f370a16eb20e74a2c5f8f5fafe0
3
+ size 470612
data/{quora-train.csv → twitter-unsup.csv} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7cd59e1ddb3a5b5d03f4a885c64e67aaf50122d9ab9ed7a476b5d2d6f7137ae8
3
- size 48270674
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a7af1ec5fc749ec8e5ea13c574aeb5c06254aa1c081e3421868079d5356b3f4
3
+ size 20895533
datasets.py DELETED
@@ -1,272 +0,0 @@
1
- #!/usr/bin/env python3
2
-
3
- '''
4
- This module contains our Dataset classes and functions that load the three datasets
5
- for training and evaluating multitask BERT.
6
-
7
- Feel free to edit code in this file if you wish to modify the way in which the data
8
- examples are preprocessed.
9
- '''
10
-
11
- import csv
12
-
13
- import torch
14
- from torch.utils.data import Dataset
15
- from tokenizer import BertTokenizer
16
-
17
-
18
- def preprocess_string(s):
19
- return ' '.join(s.lower()
20
- .replace('.', ' .')
21
- .replace('?', ' ?')
22
- .replace(',', ' ,')
23
- .replace('\'', ' \'')
24
- .split())
25
-
26
-
27
- class SentenceClassificationDataset(Dataset):
28
- def __init__(self, dataset, args):
29
- self.dataset = dataset
30
- self.p = args
31
- self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
32
-
33
- def __len__(self):
34
- return len(self.dataset)
35
-
36
- def __getitem__(self, idx):
37
- return self.dataset[idx]
38
-
39
- def pad_data(self, data):
40
-
41
- sents = [x[0] for x in data]
42
- labels = [x[1] for x in data]
43
- sent_ids = [x[2] for x in data]
44
-
45
- encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
46
- token_ids = torch.LongTensor(encoding['input_ids'])
47
- attention_mask = torch.LongTensor(encoding['attention_mask'])
48
- labels = torch.LongTensor(labels)
49
-
50
- return token_ids, attention_mask, labels, sents, sent_ids
51
-
52
- def collate_fn(self, all_data):
53
- token_ids, attention_mask, labels, sents, sent_ids= self.pad_data(all_data)
54
-
55
- batched_data = {
56
- 'token_ids': token_ids,
57
- 'attention_mask': attention_mask,
58
- 'labels': labels,
59
- 'sents': sents,
60
- 'sent_ids': sent_ids
61
- }
62
-
63
- return batched_data
64
-
65
-
66
- # Unlike SentenceClassificationDataset, we do not load labels in SentenceClassificationTestDataset.
67
- class SentenceClassificationTestDataset(Dataset):
68
- def __init__(self, dataset, args):
69
- self.dataset = dataset
70
- self.p = args
71
- self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
72
-
73
- def __len__(self):
74
- return len(self.dataset)
75
-
76
- def __getitem__(self, idx):
77
- return self.dataset[idx]
78
-
79
- def pad_data(self, data):
80
- sents = [x[0] for x in data]
81
- sent_ids = [x[1] for x in data]
82
-
83
- encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
84
- token_ids = torch.LongTensor(encoding['input_ids'])
85
- attention_mask = torch.LongTensor(encoding['attention_mask'])
86
-
87
- return token_ids, attention_mask, sents, sent_ids
88
-
89
- def collate_fn(self, all_data):
90
- token_ids, attention_mask, sents, sent_ids= self.pad_data(all_data)
91
-
92
- batched_data = {
93
- 'token_ids': token_ids,
94
- 'attention_mask': attention_mask,
95
- 'sents': sents,
96
- 'sent_ids': sent_ids
97
- }
98
-
99
- return batched_data
100
-
101
-
102
- class SentencePairDataset(Dataset):
103
- def __init__(self, dataset, args, isRegression=False):
104
- self.dataset = dataset
105
- self.p = args
106
- self.isRegression = isRegression
107
- self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
108
-
109
- def __len__(self):
110
- return len(self.dataset)
111
-
112
- def __getitem__(self, idx):
113
- return self.dataset[idx]
114
-
115
- def pad_data(self, data):
116
- sent1 = [x[0] for x in data]
117
- sent2 = [x[1] for x in data]
118
- labels = [x[2] for x in data]
119
- sent_ids = [x[3] for x in data]
120
-
121
- encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
122
- encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)
123
-
124
- token_ids = torch.LongTensor(encoding1['input_ids'])
125
- attention_mask = torch.LongTensor(encoding1['attention_mask'])
126
- token_type_ids = torch.LongTensor(encoding1['token_type_ids'])
127
-
128
- token_ids2 = torch.LongTensor(encoding2['input_ids'])
129
- attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
130
- token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])
131
- if self.isRegression:
132
- labels = torch.DoubleTensor(labels)
133
- else:
134
- labels = torch.LongTensor(labels)
135
-
136
- return (token_ids, token_type_ids, attention_mask,
137
- token_ids2, token_type_ids2, attention_mask2,
138
- labels,sent_ids)
139
-
140
- def collate_fn(self, all_data):
141
- (token_ids, token_type_ids, attention_mask,
142
- token_ids2, token_type_ids2, attention_mask2,
143
- labels, sent_ids) = self.pad_data(all_data)
144
-
145
- batched_data = {
146
- 'token_ids_1': token_ids,
147
- 'token_type_ids_1': token_type_ids,
148
- 'attention_mask_1': attention_mask,
149
- 'token_ids_2': token_ids2,
150
- 'token_type_ids_2': token_type_ids2,
151
- 'attention_mask_2': attention_mask2,
152
- 'labels': labels,
153
- 'sent_ids': sent_ids
154
- }
155
-
156
- return batched_data
157
-
158
-
159
- # Unlike SentencePairDataset, we do not load labels in SentencePairTestDataset.
160
- class SentencePairTestDataset(Dataset):
161
- def __init__(self, dataset, args):
162
- self.dataset = dataset
163
- self.p = args
164
- self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
165
-
166
- def __len__(self):
167
- return len(self.dataset)
168
-
169
- def __getitem__(self, idx):
170
- return self.dataset[idx]
171
-
172
- def pad_data(self, data):
173
- sent1 = [x[0] for x in data]
174
- sent2 = [x[1] for x in data]
175
- sent_ids = [x[2] for x in data]
176
-
177
- encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
178
- encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)
179
-
180
- token_ids = torch.LongTensor(encoding1['input_ids'])
181
- attention_mask = torch.LongTensor(encoding1['attention_mask'])
182
- token_type_ids = torch.LongTensor(encoding1['token_type_ids'])
183
-
184
- token_ids2 = torch.LongTensor(encoding2['input_ids'])
185
- attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
186
- token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])
187
-
188
-
189
- return (token_ids, token_type_ids, attention_mask,
190
- token_ids2, token_type_ids2, attention_mask2,
191
- sent_ids)
192
-
193
- def collate_fn(self, all_data):
194
- (token_ids, token_type_ids, attention_mask,
195
- token_ids2, token_type_ids2, attention_mask2,
196
- sent_ids) = self.pad_data(all_data)
197
-
198
- batched_data = {
199
- 'token_ids_1': token_ids,
200
- 'token_type_ids_1': token_type_ids,
201
- 'attention_mask_1': attention_mask,
202
- 'token_ids_2': token_ids2,
203
- 'token_type_ids_2': token_type_ids2,
204
- 'attention_mask_2': attention_mask2,
205
- 'sent_ids': sent_ids
206
- }
207
-
208
- return batched_data
209
-
210
-
211
- def load_multitask_data(sentiment_filename,paraphrase_filename,similarity_filename,split='train'):
212
- sentiment_data = []
213
- num_labels = {}
214
- if split == 'test':
215
- with open(sentiment_filename, 'r') as fp:
216
- for record in csv.DictReader(fp,delimiter = '\t'):
217
- sent = record['sentence'].lower().strip()
218
- sent_id = record['id'].lower().strip()
219
- sentiment_data.append((sent,sent_id))
220
- else:
221
- with open(sentiment_filename, 'r') as fp:
222
- for record in csv.DictReader(fp,delimiter = '\t'):
223
- sent = record['sentence'].lower().strip()
224
- sent_id = record['id'].lower().strip()
225
- label = int(record['sentiment'].strip())
226
- if label not in num_labels:
227
- num_labels[label] = len(num_labels)
228
- sentiment_data.append((sent, label,sent_id))
229
-
230
- print(f"Loaded {len(sentiment_data)} {split} examples from {sentiment_filename}")
231
-
232
- paraphrase_data = []
233
- if split == 'test':
234
- with open(paraphrase_filename, 'r') as fp:
235
- for record in csv.DictReader(fp,delimiter = '\t'):
236
- sent_id = record['id'].lower().strip()
237
- paraphrase_data.append((preprocess_string(record['sentence1']),
238
- preprocess_string(record['sentence2']),
239
- sent_id))
240
-
241
- else:
242
- with open(paraphrase_filename, 'r') as fp:
243
- for record in csv.DictReader(fp,delimiter = '\t'):
244
- try:
245
- sent_id = record['id'].lower().strip()
246
- paraphrase_data.append((preprocess_string(record['sentence1']),
247
- preprocess_string(record['sentence2']),
248
- int(float(record['is_duplicate'])),sent_id))
249
- except:
250
- pass
251
-
252
- print(f"Loaded {len(paraphrase_data)} {split} examples from {paraphrase_filename}")
253
-
254
- similarity_data = []
255
- if split == 'test':
256
- with open(similarity_filename, 'r') as fp:
257
- for record in csv.DictReader(fp,delimiter = '\t'):
258
- sent_id = record['id'].lower().strip()
259
- similarity_data.append((preprocess_string(record['sentence1']),
260
- preprocess_string(record['sentence2'])
261
- ,sent_id))
262
- else:
263
- with open(similarity_filename, 'r') as fp:
264
- for record in csv.DictReader(fp,delimiter = '\t'):
265
- sent_id = record['id'].lower().strip()
266
- similarity_data.append((preprocess_string(record['sentence1']),
267
- preprocess_string(record['sentence2']),
268
- float(record['similarity']),sent_id))
269
-
270
- print(f"Loaded {len(similarity_data)} {split} examples from {similarity_filename}")
271
-
272
- return sentiment_data, num_labels, paraphrase_data, similarity_data
justfile ADDED
@@ -0,0 +1 @@
1
+ python classifier.py --num-cpu-cores 8 --batch_size_sst 64 --batch_size_cfimdb 8
multitask_classifier.py DELETED
@@ -1,340 +0,0 @@
1
- '''
2
- Multitask BERT class, starter training code, evaluation, and test code.
3
-
4
- Of note are:
5
- * class MultitaskBERT: Your implementation of multitask BERT.
6
- * function train_multitask: Training procedure for MultitaskBERT. Starter code
7
- copies training procedure from `classifier.py` (single-task SST).
8
- * function test_multitask: Test procedure for MultitaskBERT. This function generates
9
- the required files for submission.
10
-
11
- Running `python multitask_classifier.py` trains and tests your MultitaskBERT and
12
- writes all required submission files.
13
- '''
14
-
15
- import random, numpy as np, argparse
16
- from types import SimpleNamespace
17
-
18
- import torch
19
- from torch import nn
20
- import torch.nn.functional as F
21
- from torch.utils.data import DataLoader
22
-
23
- from bert import BertModel
24
- from optimizer import AdamW
25
- from tqdm import tqdm
26
-
27
- from datasets import (
28
- SentenceClassificationDataset,
29
- SentenceClassificationTestDataset,
30
- SentencePairDataset,
31
- SentencePairTestDataset,
32
- load_multitask_data
33
- )
34
-
35
- from evaluation import model_eval_sst, model_eval_multitask, model_eval_test_multitask
36
-
37
-
38
- TQDM_DISABLE=False
39
-
40
-
41
- # Fix the random seed.
42
- def seed_everything(seed=11711):
43
- random.seed(seed)
44
- np.random.seed(seed)
45
- torch.manual_seed(seed)
46
- torch.cuda.manual_seed(seed)
47
- torch.cuda.manual_seed_all(seed)
48
- torch.backends.cudnn.benchmark = False
49
- torch.backends.cudnn.deterministic = True
50
-
51
-
52
- BERT_HIDDEN_SIZE = 768
53
- N_SENTIMENT_CLASSES = 5
54
-
55
-
56
- class MultitaskBERT(nn.Module):
57
- '''
58
- This module should use BERT for 3 tasks:
59
-
60
- - Sentiment classification (predict_sentiment)
61
- - Paraphrase detection (predict_paraphrase)
62
- - Semantic Textual Similarity (predict_similarity)
63
- '''
64
- def __init__(self, config):
65
- super(MultitaskBERT, self).__init__()
66
- self.bert = BertModel.from_pretrained('bert-base-uncased')
67
- # last-linear-layer mode does not require updating BERT paramters.
68
- assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
69
- for param in self.bert.parameters():
70
- if config.fine_tune_mode == 'last-linear-layer':
71
- param.requires_grad = False
72
- elif config.fine_tune_mode == 'full-model':
73
- param.requires_grad = True
74
- # You will want to add layers here to perform the downstream tasks.
75
- ### TODO
76
- raise NotImplementedError
77
-
78
-
79
- def forward(self, input_ids, attention_mask):
80
- 'Takes a batch of sentences and produces embeddings for them.'
81
- # The final BERT embedding is the hidden state of [CLS] token (the first token)
82
- # Here, you can start by just returning the embeddings straight from BERT.
83
- # When thinking of improvements, you can later try modifying this
84
- # (e.g., by adding other layers).
85
- ### TODO
86
- raise NotImplementedError
87
-
88
-
89
- def predict_sentiment(self, input_ids, attention_mask):
90
- '''Given a batch of sentences, outputs logits for classifying sentiment.
91
- There are 5 sentiment classes:
92
- (0 - negative, 1- somewhat negative, 2- neutral, 3- somewhat positive, 4- positive)
93
- Thus, your output should contain 5 logits for each sentence.
94
- '''
95
- ### TODO
96
- raise NotImplementedError
97
-
98
-
99
- def predict_paraphrase(self,
100
- input_ids_1, attention_mask_1,
101
- input_ids_2, attention_mask_2):
102
- '''Given a batch of pairs of sentences, outputs a single logit for predicting whether they are paraphrases.
103
- Note that your output should be unnormalized (a logit); it will be passed to the sigmoid function
104
- during evaluation.
105
- '''
106
- ### TODO
107
- raise NotImplementedError
108
-
109
-
110
- def predict_similarity(self,
111
- input_ids_1, attention_mask_1,
112
- input_ids_2, attention_mask_2):
113
- '''Given a batch of pairs of sentences, outputs a single logit corresponding to how similar they are.
114
- Note that your output should be unnormalized (a logit).
115
- '''
116
- ### TODO
117
- raise NotImplementedError
118
-
119
-
120
-
121
-
122
- def save_model(model, optimizer, args, config, filepath):
123
- save_info = {
124
- 'model': model.state_dict(),
125
- 'optim': optimizer.state_dict(),
126
- 'args': args,
127
- 'model_config': config,
128
- 'system_rng': random.getstate(),
129
- 'numpy_rng': np.random.get_state(),
130
- 'torch_rng': torch.random.get_rng_state(),
131
- }
132
-
133
- torch.save(save_info, filepath)
134
- print(f"save the model to {filepath}")
135
-
136
-
137
- def train_multitask(args):
138
- '''Train MultitaskBERT.
139
-
140
- Currently only trains on SST dataset. The way you incorporate training examples
141
- from other datasets into the training procedure is up to you. To begin, take a
142
- look at test_multitask below to see how you can use the custom torch `Dataset`s
143
- in datasets.py to load in examples from the Quora and SemEval datasets.
144
- '''
145
-     device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
-     # Create the data and its corresponding datasets and dataloader.
-     sst_train_data, num_labels,para_train_data, sts_train_data = load_multitask_data(args.sst_train,args.para_train,args.sts_train, split ='train')
-     sst_dev_data, num_labels,para_dev_data, sts_dev_data = load_multitask_data(args.sst_dev,args.para_dev,args.sts_dev, split ='train')
-
-     sst_train_data = SentenceClassificationDataset(sst_train_data, args)
-     sst_dev_data = SentenceClassificationDataset(sst_dev_data, args)
-
-     sst_train_dataloader = DataLoader(sst_train_data, shuffle=True, batch_size=args.batch_size,
-                                       collate_fn=sst_train_data.collate_fn)
-     sst_dev_dataloader = DataLoader(sst_dev_data, shuffle=False, batch_size=args.batch_size,
-                                     collate_fn=sst_dev_data.collate_fn)
-
-     # Init model.
-     config = {'hidden_dropout_prob': args.hidden_dropout_prob,
-               'num_labels': num_labels,
-               'hidden_size': 768,
-               'data_dir': '.',
-               'fine_tune_mode': args.fine_tune_mode}
-
-     config = SimpleNamespace(**config)
-
-     model = MultitaskBERT(config)
-     model = model.to(device)
-
-     lr = args.lr
-     optimizer = AdamW(model.parameters(), lr=lr)
-     best_dev_acc = 0
-
-     # Run for the specified number of epochs.
-     for epoch in range(args.epochs):
-         model.train()
-         train_loss = 0
-         num_batches = 0
-         for batch in tqdm(sst_train_dataloader, desc=f'train-{epoch}', disable=TQDM_DISABLE):
-             b_ids, b_mask, b_labels = (batch['token_ids'],
-                                        batch['attention_mask'], batch['labels'])
-
-             b_ids = b_ids.to(device)
-             b_mask = b_mask.to(device)
-             b_labels = b_labels.to(device)
-
-             optimizer.zero_grad()
-             logits = model.predict_sentiment(b_ids, b_mask)
-             loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size
-
-             loss.backward()
-             optimizer.step()
-
-             train_loss += loss.item()
-             num_batches += 1
-
-         train_loss = train_loss / (num_batches)
-
-         train_acc, train_f1, *_ = model_eval_sst(sst_train_dataloader, model, device)
-         dev_acc, dev_f1, *_ = model_eval_sst(sst_dev_dataloader, model, device)
-
-         if dev_acc > best_dev_acc:
-             best_dev_acc = dev_acc
-             save_model(model, optimizer, args, config, args.filepath)
-
-         print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")
-
-
- def test_multitask(args):
-     '''Test and save predictions on the dev and test sets of all three tasks.'''
-     with torch.no_grad():
-         device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
-         saved = torch.load(args.filepath)
-         config = saved['model_config']
-
-         model = MultitaskBERT(config)
-         model.load_state_dict(saved['model'])
-         model = model.to(device)
-         print(f"Loaded model to test from {args.filepath}")
-
-         sst_test_data, num_labels,para_test_data, sts_test_data = \
-             load_multitask_data(args.sst_test,args.para_test, args.sts_test, split='test')
-
-         sst_dev_data, num_labels,para_dev_data, sts_dev_data = \
-             load_multitask_data(args.sst_dev,args.para_dev,args.sts_dev,split='dev')
-
-         sst_test_data = SentenceClassificationTestDataset(sst_test_data, args)
-         sst_dev_data = SentenceClassificationDataset(sst_dev_data, args)
-
-         sst_test_dataloader = DataLoader(sst_test_data, shuffle=True, batch_size=args.batch_size,
-                                          collate_fn=sst_test_data.collate_fn)
-         sst_dev_dataloader = DataLoader(sst_dev_data, shuffle=False, batch_size=args.batch_size,
-                                         collate_fn=sst_dev_data.collate_fn)
-
-         para_test_data = SentencePairTestDataset(para_test_data, args)
-         para_dev_data = SentencePairDataset(para_dev_data, args)
-
-         para_test_dataloader = DataLoader(para_test_data, shuffle=True, batch_size=args.batch_size,
-                                           collate_fn=para_test_data.collate_fn)
-         para_dev_dataloader = DataLoader(para_dev_data, shuffle=False, batch_size=args.batch_size,
-                                          collate_fn=para_dev_data.collate_fn)
-
-         sts_test_data = SentencePairTestDataset(sts_test_data, args)
-         sts_dev_data = SentencePairDataset(sts_dev_data, args, isRegression=True)
-
-         sts_test_dataloader = DataLoader(sts_test_data, shuffle=True, batch_size=args.batch_size,
-                                          collate_fn=sts_test_data.collate_fn)
-         sts_dev_dataloader = DataLoader(sts_dev_data, shuffle=False, batch_size=args.batch_size,
-                                         collate_fn=sts_dev_data.collate_fn)
-
-         dev_sentiment_accuracy,dev_sst_y_pred, dev_sst_sent_ids, \
-             dev_paraphrase_accuracy, dev_para_y_pred, dev_para_sent_ids, \
-             dev_sts_corr, dev_sts_y_pred, dev_sts_sent_ids = model_eval_multitask(sst_dev_dataloader,
-                                                                                   para_dev_dataloader,
-                                                                                   sts_dev_dataloader, model, device)
-
-         test_sst_y_pred, \
-             test_sst_sent_ids, test_para_y_pred, test_para_sent_ids, test_sts_y_pred, test_sts_sent_ids = \
-                 model_eval_test_multitask(sst_test_dataloader,
-                                           para_test_dataloader,
-                                           sts_test_dataloader, model, device)
-
-         with open(args.sst_dev_out, "w+") as f:
-             print(f"dev sentiment acc :: {dev_sentiment_accuracy :.3f}")
-             f.write(f"id \t Predicted_Sentiment \n")
-             for p, s in zip(dev_sst_sent_ids, dev_sst_y_pred):
-                 f.write(f"{p} , {s} \n")
-
-         with open(args.sst_test_out, "w+") as f:
-             f.write(f"id \t Predicted_Sentiment \n")
-             for p, s in zip(test_sst_sent_ids, test_sst_y_pred):
-                 f.write(f"{p} , {s} \n")
-
-         with open(args.para_dev_out, "w+") as f:
-             print(f"dev paraphrase acc :: {dev_paraphrase_accuracy :.3f}")
-             f.write(f"id \t Predicted_Is_Paraphrase \n")
-             for p, s in zip(dev_para_sent_ids, dev_para_y_pred):
-                 f.write(f"{p} , {s} \n")
-
-         with open(args.para_test_out, "w+") as f:
-             f.write(f"id \t Predicted_Is_Paraphrase \n")
-             for p, s in zip(test_para_sent_ids, test_para_y_pred):
-                 f.write(f"{p} , {s} \n")
-
-         with open(args.sts_dev_out, "w+") as f:
-             print(f"dev sts corr :: {dev_sts_corr :.3f}")
-             f.write(f"id \t Predicted_Similiary \n")
-             for p, s in zip(dev_sts_sent_ids, dev_sts_y_pred):
-                 f.write(f"{p} , {s} \n")
-
-         with open(args.sts_test_out, "w+") as f:
-             f.write(f"id \t Predicted_Similiary \n")
-             for p, s in zip(test_sts_sent_ids, test_sts_y_pred):
-                 f.write(f"{p} , {s} \n")
-
-
- def get_args():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--sst_train", type=str, default="data/ids-sst-train.csv")
-     parser.add_argument("--sst_dev", type=str, default="data/ids-sst-dev.csv")
-     parser.add_argument("--sst_test", type=str, default="data/ids-sst-test-student.csv")
-
-     parser.add_argument("--para_train", type=str, default="data/quora-train.csv")
-     parser.add_argument("--para_dev", type=str, default="data/quora-dev.csv")
-     parser.add_argument("--para_test", type=str, default="data/quora-test-student.csv")
-
-     parser.add_argument("--sts_train", type=str, default="data/sts-train.csv")
-     parser.add_argument("--sts_dev", type=str, default="data/sts-dev.csv")
-     parser.add_argument("--sts_test", type=str, default="data/sts-test-student.csv")
-
-     parser.add_argument("--seed", type=int, default=11711)
-     parser.add_argument("--epochs", type=int, default=10)
-     parser.add_argument("--fine-tune-mode", type=str,
-                         help='last-linear-layer: the BERT parameters are frozen and the task specific head parameters are updated; full-model: BERT parameters are updated as well',
-                         choices=('last-linear-layer', 'full-model'), default="last-linear-layer")
-     parser.add_argument("--use_gpu", action='store_true')
-
-     parser.add_argument("--sst_dev_out", type=str, default="predictions/sst-dev-output.csv")
-     parser.add_argument("--sst_test_out", type=str, default="predictions/sst-test-output.csv")
-
-     parser.add_argument("--para_dev_out", type=str, default="predictions/para-dev-output.csv")
-     parser.add_argument("--para_test_out", type=str, default="predictions/para-test-output.csv")
-
-     parser.add_argument("--sts_dev_out", type=str, default="predictions/sts-dev-output.csv")
-     parser.add_argument("--sts_test_out", type=str, default="predictions/sts-test-output.csv")
-
-     parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
-     parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
-     parser.add_argument("--lr", type=float, help="learning rate", default=1e-5)
-
-     args = parser.parse_args()
-     return args
-
-
- if __name__ == "__main__":
-     args = get_args()
-     args.filepath = f'{args.fine_tune_mode}-{args.epochs}-{args.lr}-multitask.pt' # Save path.
-     seed_everything(args.seed)  # Fix the seed for reproducibility.
-     train_multitask(args)
-     test_multitask(args)
 
prompt ADDED
@@ -0,0 +1,3 @@
+ I want to fine-tune minBERT with Unsupervised SimCSE to do sentiment analysis, but I do not yet know how to go about it. As I understand it, I would fine-tune minBERT with SimCSE to get better embeddings, and then feed those embeddings into a SentimentClassifier for classification. What is the right approach, though?
+ I have considered the two options below (there may be others I have not thought of). Please take a look!
+ 1. Fine-tune minBERT with SimCSE first, then fine-tune the SentimentClassifier: use the STS-B dataset or the Twitter Sentiment Dataset to fine-tune minBERT, then evaluate the
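Whichever option is chosen, the core ingredient is the unsupervised SimCSE objective itself: encode the same batch of sentences twice with dropout active, treat the two dropout-perturbed embeddings of each sentence as a positive pair, and use the other sentences in the batch as negatives. Below is a minimal sketch of that loss, not part of this commit; encode_fn, the 0.05 temperature, and the pooling it implies are illustrative assumptions.

import torch
import torch.nn.functional as F

def unsup_simcse_loss(encode_fn, sents, temperature=0.05):
    # encode_fn maps a batch of sentences to one (N, H) embedding per sentence.
    # Dropout must be active (model.train()) so the two passes give different "views".
    z1 = encode_fn(sents)                                                # first dropout view
    z2 = encode_fn(sents)                                                # second dropout view
    sim = F.cosine_similarity(z1.unsqueeze(1), z2.unsqueeze(0), dim=-1)  # (N, N) pairwise similarities
    labels = torch.arange(sim.size(0), device=sim.device)                # positive of i sits on the diagonal
    return F.cross_entropy(sim / temperature, labels)                    # InfoNCE with in-batch negatives

The sentiment stage that follows is then ordinary supervised fine-tuning: pass the (now better) sentence embeddings through the SentimentClassifier head and train it with cross-entropy on the SST labels.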
trainings/last-layer-w-dropout.txt CHANGED
@@ -1,4 +1,4 @@
- Training Sentiment Classifier on SST...
+ Training Sentiment Classifier on SST...
  load 8544 data from data/ids-sst-train.csv
  load 1101 data from data/ids-sst-dev.csv
  Epoch 0: train loss :: 1.458, train acc :: 0.460, dev acc :: 0.442
@@ -14,7 +14,7 @@ Epoch 9: train loss :: 1.227, train acc :: 0.509, dev acc :: 0.475
  Evaluating on SST...
  load model from sst-classifier.pt
  load 1101 data from data/ids-sst-dev.csv
- DONE DEV
+ DONE DEV
  DONE Test
  dev acc :: 0.475
  Training Sentiment Classifier on cfimdb...
@@ -33,6 +33,6 @@ Epoch 9: train loss :: 0.407, train acc :: 0.895, dev acc :: 0.873
  Evaluating on cfimdb...
  load model from cfimdb-classifier.pt
  load 245 data from data/ids-cfimdb-dev.csv
- DONE DEV
- DONE Test
+ DONE DEV
+ DONE Test
  dev acc :: 0.873
unsup_simcse.py ADDED
@@ -0,0 +1,204 @@
+ import csv
+ import torch
+ import random
+ import argparse
+ import numpy as np
+
+ from tqdm import tqdm
+ from types import SimpleNamespace
+ from torch.utils.data import Dataset, DataLoader
+ from sklearn.metrics import f1_score, accuracy_score
+
+ from bert import BertModel
+ from optimizer import AdamW
+ from classifier import seed_everything, tokenizer
+ from classifier import SentimentDataset, BertSentimentClassifier
+
+
+ TQDM_DISABLE = False
+
+
+ class TwitterDataset(Dataset):
+     def __init__(self, dataset, args):
+         self.dataset = dataset
+         self.p = args
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         return self.dataset[idx]
+
+     def pad_data(self, sents):
+         encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
+         token_ids = torch.LongTensor(encoding['input_ids'])
+         attention_mask = torch.LongTensor(encoding['attention_mask'])
+
+         return token_ids, attention_mask
+
+     def collate_fn(self, sents):
+         token_ids, attention_mask = self.pad_data(sents)
+
+         batched_data = {
+             'token_ids': token_ids,
+             'attention_mask': attention_mask,
+         }
+
+         return batched_data
+
+
+ def load_data(filename, flag='train'):
+     '''
+     - for Twitter dataset: list of sentences
+     - for SST/CFIMDB dataset: list of (sent, [label], sent_id)
+     '''
+     num_labels = set()
+     data = []
+     with open(filename, 'r') as fp:
+         for record in csv.DictReader(fp, delimiter=','):
+             if flag == 'twitter':
+                 sent = record['clean_text'].lower().strip()
+                 data.append(sent)
+             elif flag == 'test':
+                 sent = record['sentence'].lower().strip()
+                 sent_id = record['id'].lower().strip()
+                 data.append((sent, sent_id))
+             else:
+                 sent = record['sentence'].lower().strip()
+                 sent_id = record['id'].lower().strip()
+                 label = int(record['sentiment'].strip())
+                 num_labels.add(label)
+                 data.append((sent, label, sent_id))
+     print(f"load {len(data)} data from {filename}")
+
+     if flag == 'train':
+         return data, len(num_labels)
+     else:
+         return data
+
+
+ def save_model(model, optimizer, args, config, filepath):
+     save_info = {
+         'model': model.state_dict(),
+         'optim': optimizer.state_dict(),
+         'args': args,
+         'model_config': config,
+         'system_rng': random.getstate(),
+         'numpy_rng': np.random.get_state(),
+         'torch_rng': torch.random.get_rng_state(),
+     }
+
+     torch.save(save_info, filepath)
+     print(f"save the model to {filepath}")
+
+
+ def train(args):
+     '''
+     Training Pipeline
+     -----------------
+     1. Load the Twitter Sentiment and SST Dataset.
+     2. Determine batch_size (64) and number of batches (?).
+     3. Initialize SentimentClassifier (including bert).
+     4. Loop through 10 epochs.
+     5. Finetune minBERT with the SimCSE loss function.
+     6. Finetune the Classifier with the cross-entropy function.
+     7. Backpropagate using the Adam Optimizer for both.
+     8. Evaluate the model on the dev dataset.
+     9. If dev_acc > best_dev_acc: save_model(...)
+     '''
+
+     twitter_data = load_data(args.train_bert, 'twitter')
+     train_data, num_labels = load_data(args.train, 'train')
+     dev_data = load_data(args.dev, 'valid')
+
+     twitter_dataset = TwitterDataset(twitter_data, args)
+     train_dataset = SentimentDataset(train_data, args)
+     dev_dataset = SentimentDataset(dev_data, args)
+
+     twitter_dataloader = DataLoader(twitter_dataset, shuffle=True, batch_size=args.batch_size_cse,
+                                     num_workers=args.num_cpu_cores, collate_fn=twitter_dataset.collate_fn)
+     train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_classifier,
+                                   num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
+     dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size_classifier,
+                                 num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)
+
+     config = SimpleNamespace(
+         hidden_dropout_prob=args.hidden_dropout_prob,
+         num_labels=num_labels,
+         hidden_size=768,
+         data_dir='.',
+         fine_tune_mode='full-model'
+     )
+
+     device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
+     model = BertSentimentClassifier(config)
+     model = model.to(device)
+
+     optimizer_cse = AdamW(model.bert.parameters(), lr=args.lr_cse)
+     optimizer_classifier = AdamW(model.parameters(), lr=args.lr_classifier)
+     best_dev_acc = 0
+
+     for epoch in range(args.epochs):
+         model.bert.train()
+         train_loss = num_batches = 0
+         for batch in tqdm(twitter_dataloader, f'train-twitter-{epoch}', leave=False, disable=TQDM_DISABLE):
+             b_ids, b_mask = batch['token_ids'], batch['attention_mask']
+             b_ids = b_ids.to(device)
+             b_mask = b_mask.to(device)
+
+             optimizer_cse.zero_grad()
+             logits = model.bert.embed(b_ids)
+             logits = model.bert.encode(logits, b_mask)
+             # NOTE: the SimCSE loss and optimizer step are not implemented yet in this commit
+             # (see the sketch after this file listing).
+
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--seed", type=int, default=11711)
+     parser.add_argument("--num-cpu-cores", type=int, default=4)
+     parser.add_argument("--epochs", type=int, default=10)
+     parser.add_argument("--use_gpu", action='store_true')
+     parser.add_argument("--batch_size_cse", help="'unsup': 64, 'sup': 512", type=int)
+     parser.add_argument("--batch_size_classifier", help="'sst': 64, 'cfimdb': 8", type=int)
+     parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
+     parser.add_argument("--lr_cse", type=float, default=2e-5)
+     parser.add_argument("--lr_classifier", type=float, default=1e-5)
+
+     args = parser.parse_args()
+     return args
+
+
+ if __name__ == "__main__":
+     args = get_args()
+     seed_everything(args.seed)
+     torch.set_num_threads(args.num_cpu_cores)
+
+     print('Finetuning minBERT with Unsupervised SimCSE...')
+     config = SimpleNamespace(
+         filepath='contrastive-nli.pt',
+         lr_cse=args.lr_cse,
+         lr_classifier=args.lr_classifier,
+         hidden_dropout_prob=args.hidden_dropout_prob,
+         num_cpu_cores=args.num_cpu_cores,
+         use_gpu=args.use_gpu,
+         epochs=args.epochs,
+         batch_size_cse=args.batch_size_cse,
+         batch_size_classifier=args.batch_size_classifier,
+         train_bert='data/twitter-unsup.csv',
+         train='data/ids-sst-train.csv',
+         dev='data/ids-sst-dev.csv',
+         test='data/ids-sst-test-student.csv',
+         dev_out='predictions/full-model-sst-dev-out.csv',
+         test_out='predictions/full-model-sst-test-out.csv'
+     )
+
+     train(config)
+
+     # model = BertModel.from_pretrained('bert-base-uncased')
+
+     # model.eval()
+
+     # s = set()
+     # for param in model.parameters():
+     #     s.add(param.requires_grad)
+
+     # print(s)
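As noted in the listing, train() stops right after encoding the Twitter batch. A minimal sketch of how that contrastive step could continue, reusing this file's names (model.bert.embed / model.bert.encode, optimizer_cse); the masked mean pooling over tokens and the 0.05 temperature are assumptions, not something this commit specifies.

import torch
import torch.nn.functional as F

def simcse_step(model, optimizer_cse, b_ids, b_mask, temperature=0.05):
    # Two forward passes of the same batch: dropout makes the embeddings differ,
    # which is exactly the positive-pair construction of unsupervised SimCSE.
    def embed_once():
        hidden = model.bert.encode(model.bert.embed(b_ids), b_mask)  # (N, T, H) token states
        mask = b_mask.unsqueeze(-1).float()
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1)          # masked mean-pool -> (N, H)

    z1, z2 = embed_once(), embed_once()
    sim = F.cosine_similarity(z1.unsqueeze(1), z2.unsqueeze(0), dim=-1) / temperature  # (N, N)
    labels = torch.arange(sim.size(0), device=sim.device)             # positive = same index
    loss = F.cross_entropy(sim, labels)

    optimizer_cse.zero_grad()
    loss.backward()
    optimizer_cse.step()
    return loss.item()

Steps 6-9 of the docstring (classifier fine-tuning with cross-entropy on train_dataloader, evaluation on the dev set, and checkpointing via save_model when dev_acc improves) would then follow the same pattern as the existing single-task classifier loop.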