hduc-le committed on
Commit 57010c7 · 1 Parent(s): 1202255

Upload model_token_cls.ipynb

Files changed (1)
  1. model_token_cls.ipynb +611 -0
model_token_cls.ipynb ADDED
@@ -0,0 +1,611 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "from transformers.modeling_outputs import (\n",
+ "    Seq2SeqQuestionAnsweringModelOutput,\n",
+ "    Seq2SeqSequenceClassifierOutput,\n",
+ "    BaseModelOutput,\n",
+ ")\n",
+ "from transformers import (\n",
+ "    T5ForQuestionAnswering,\n",
+ "    T5PreTrainedModel,\n",
+ "    MBartPreTrainedModel,\n",
+ "    MBartModel,\n",
+ "    T5Config,\n",
+ "    T5Model,\n",
+ "    T5EncoderModel,\n",
+ "    get_scheduler\n",
+ ")\n",
+ "from tqdm import tqdm\n",
+ "from dataclasses import dataclass\n",
+ "from typing import List, Optional, Tuple, Union\n",
+ "\n",
+ "import numpy as np\n",
+ "import random\n",
+ "import os\n",
+ "from datetime import datetime\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "from transformers import AutoTokenizer\n",
+ "from sklearn.model_selection import train_test_split"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "import yaml\n",
+ "from addict import Dict\n",
+ "\n",
+ "\n",
+ "def load_json(file_path):\n",
+ "    with open(file_path, \"r\", encoding=\"utf-8-sig\") as f:\n",
+ "        data = json.load(f)\n",
+ "    return data\n",
+ "\n",
+ "\n",
+ "def read_config(path):\n",
+ "    # read yaml and return contents\n",
+ "    with open(path, \"r\") as file:\n",
+ "        try:\n",
+ "            return Dict(yaml.safe_load(file))\n",
+ "        except yaml.YAMLError as exc:\n",
+ "            print(exc)\n",
+ "\n",
+ "\n",
+ "def batch_to_device(batch: dict, device: str):\n",
+ "    for k in batch:\n",
+ "        batch[k] = batch[k].to(device)\n",
+ "    return batch\n",
+ "\n",
+ "\n",
+ "def save_json(obj, path):\n",
+ "    with open(path, \"w\") as outfile:\n",
+ "        json.dump(obj, outfile, ensure_ascii=False, indent=2)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@dataclass\n",
+ "class TokenClassificationOutput:\n",
+ "    loss: Optional[torch.FloatTensor] = None\n",
+ "    sent_loss: Optional[torch.FloatTensor] = None\n",
+ "    token_loss: Optional[torch.FloatTensor] = None\n",
+ "    claim_logits: torch.FloatTensor = None\n",
+ "    evidence_logits: torch.FloatTensor = None\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def random_seed(value):\n",
+ "    torch.backends.cudnn.deterministic = True\n",
+ "    torch.manual_seed(value)\n",
+ "    torch.cuda.manual_seed(value)\n",
+ "    np.random.seed(value)\n",
+ "    random.seed(value)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@dataclass\n",
+ "class TrainingArguments:\n",
+ "    data_path = \"data/ise-dsc01-train.json\"\n",
+ "    model_name = \"VietAI/vit5-base\"\n",
+ "    tokenizer_name = \"VietAI/vit5-base\"\n",
+ "    gradient_accumulation_steps = 8\n",
+ "    gradient_checkpointing = False\n",
+ "    num_epochs = 10\n",
+ "    lr = 3.0e-5\n",
+ "    weight_decay = 1.0e-2\n",
+ "    scheduler_name = \"cosine\"\n",
+ "    warmup_steps = 0\n",
+ "    patience = 3\n",
+ "    max_seq_length = 1024\n",
+ "    seed = 1401\n",
+ "    test_size = 0.1\n",
+ "    train_batch_size = 1\n",
+ "    val_batch_size = 1\n",
+ "\n",
+ "    save_best = True\n",
+ "\n",
+ "    freeze_backbone = False\n",
+ "    freeze_encoder = False\n",
+ "    freeze_decoder = False\n",
+ "\n",
+ "training_args = TrainingArguments()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "_LABEL_MAPPING = {\"SUPPORTED\": 0, \"NEI\": 1, \"REFUTED\": 2}\n",
+ "\n",
+ "class TokenStanceDataset(Dataset):\n",
+ "    def __init__(self, dataset, dataset_keys, tokenizer, max_seq_length=1024) -> None:\n",
+ "        super().__init__()\n",
+ "        self.tokenizer = tokenizer\n",
+ "        self.max_seq_length = max_seq_length\n",
+ "        self.dataset = dataset\n",
+ "        self.dataset_keys = dataset_keys\n",
+ "\n",
+ "    def __getitem__(self, idx):\n",
+ "        data_id = self.dataset_keys[idx]\n",
+ "        data_item = self.dataset[data_id]\n",
+ "\n",
+ "        claim = data_item['claim']\n",
+ "        evidence = data_item['evidence']\n",
+ "        context = data_item['context']\n",
+ "\n",
+ "        encodings = self.tokenizer(\n",
+ "            context,\n",
+ "            claim,\n",
+ "            truncation=True,\n",
+ "            padding=\"max_length\",\n",
+ "            max_length=self.max_seq_length,\n",
+ "            return_tensors=\"pt\"\n",
+ "        )\n",
+ "\n",
+ "        if evidence is None:\n",
+ "            start_position, end_position = 0, 0\n",
+ "        else:\n",
+ "            start_idx = context.find(evidence)\n",
+ "            end_idx = start_idx + len(evidence)\n",
+ "\n",
+ "            evidence_start = start_idx\n",
+ "            evidence_end = end_idx\n",
+ "\n",
+ "            if context[start_idx: end_idx] == evidence:\n",
+ "                evidence_end = end_idx\n",
+ "            else:\n",
+ "                for n in [1, 2]:\n",
+ "                    if context[start_idx-n: end_idx-n] == evidence:\n",
+ "                        evidence_start = start_idx - n\n",
+ "                        evidence_end = end_idx - n\n",
+ "\n",
+ "            if evidence_start < 0:\n",
+ "                evidence_start = 0\n",
+ "\n",
+ "            if evidence_end < 0:\n",
+ "                evidence_end = 0\n",
+ "\n",
+ "            start_position = encodings.char_to_token(0, evidence_start)\n",
+ "            end_position = encodings.char_to_token(0, evidence_end)\n",
+ "\n",
+ "            trace_back = 1\n",
+ "            while end_position is None:\n",
+ "                end_position = encodings.char_to_token(0, evidence_end-trace_back)\n",
+ "                trace_back += 1\n",
+ "\n",
+ "            if start_position is None:\n",
+ "                start_position = 0\n",
+ "                end_position = 0\n",
+ "\n",
+ "        evidence_labels = torch.zeros(self.max_seq_length,)\n",
+ "        if end_position > 0:\n",
+ "            evidence_labels[start_position: end_position] = 1\n",
+ "        evidence_labels = evidence_labels.long()\n",
+ "\n",
+ "        #print(\"====\")\n",
+ "        #print(evidence)\n",
+ "        #print(self.tokenizer.decode(encodings.input_ids[0][evidence_labels.bool()]))\n",
+ "\n",
+ "        label = torch.tensor(_LABEL_MAPPING[data_item[\"verdict\"]], dtype=torch.long)\n",
+ "\n",
+ "        return {\n",
+ "            \"input_ids\": encodings.input_ids.squeeze(0),\n",
+ "            \"attention_mask\": encodings.attention_mask.squeeze(0),\n",
+ "            \"evidence_labels\": evidence_labels,\n",
+ "            \"labels\": label\n",
+ "        }\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        return len(self.dataset)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "random_seed(training_args.seed)\n",
+ "\n",
+ "data = load_json(training_args.data_path)\n",
+ "\n",
+ "data_keys = list(data.keys())\n",
+ "\n",
+ "train_keys, dev_keys = train_test_split(\n",
+ "    data_keys,\n",
+ "    test_size=training_args.test_size,\n",
+ "    random_state=training_args.seed,\n",
+ "    shuffle=True,\n",
+ ")\n",
+ "\n",
+ "train_set = {k: v for k, v in data.items() if k in train_keys}\n",
+ "dev_set = {k: v for k, v in data.items() if k in dev_keys}\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
+ "    training_args.tokenizer_name, use_fast=True\n",
+ ")\n",
+ "\n",
+ "train_dataset = TokenStanceDataset(\n",
+ "    train_set, train_keys, tokenizer, training_args.max_seq_length\n",
+ ")\n",
+ "val_dataset = TokenStanceDataset(\n",
+ "    dev_set, dev_keys, tokenizer, training_args.max_seq_length\n",
+ ")\n",
+ "\n",
+ "train_dataloader = DataLoader(\n",
+ "    train_dataset, batch_size=training_args.train_batch_size, shuffle=True\n",
+ ")\n",
+ "val_dataloader = DataLoader(\n",
+ "    val_dataset, batch_size=training_args.val_batch_size, shuffle=False\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class T5FeedForwardHead(nn.Module):\n",
+ "    \"\"\"Head for sentence-level classification tasks.\"\"\"\n",
+ "\n",
+ "    def __init__(self, config, out_dim):\n",
+ "        super().__init__()\n",
+ "        self.dense = nn.Linear(config.d_model, config.d_model)\n",
+ "        self.dropout = nn.Dropout(p=config.classifier_dropout)\n",
+ "        self.out_proj = nn.Linear(config.d_model, out_dim)\n",
+ "\n",
+ "    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n",
+ "        hidden_states = self.dropout(hidden_states)\n",
+ "        hidden_states = self.dense(hidden_states)\n",
+ "        hidden_states = torch.relu(hidden_states)\n",
+ "        hidden_states = self.dropout(hidden_states)\n",
+ "        hidden_states = self.out_proj(hidden_states)\n",
+ "        return hidden_states\n",
+ "\n",
+ "\n",
+ "class ViT5ForTokenClassification(T5PreTrainedModel):\n",
+ "    def __init__(self, config):\n",
+ "        super().__init__(config)\n",
+ "        self.transformer = T5Model(config)\n",
+ "        self.num_labels = 2\n",
+ "        self.num_verdicts = 3\n",
+ "\n",
+ "        self.verdict_head = T5FeedForwardHead(config, self.num_verdicts)\n",
+ "        self.evidence_head = T5FeedForwardHead(config, self.num_labels)\n",
+ "\n",
+ "    def forward(\n",
+ "        self,\n",
+ "        input_ids: torch.LongTensor = None,\n",
+ "        attention_mask: Optional[torch.Tensor] = None,\n",
+ "        decoder_input_ids: Optional[torch.LongTensor] = None,\n",
+ "        decoder_attention_mask: Optional[torch.LongTensor] = None,\n",
+ "        head_mask: Optional[torch.Tensor] = None,\n",
+ "        decoder_head_mask: Optional[torch.Tensor] = None,\n",
+ "        cross_attn_head_mask: Optional[torch.Tensor] = None,\n",
+ "        encoder_outputs: Optional[List[torch.FloatTensor]] = None,\n",
+ "        inputs_embeds: Optional[torch.FloatTensor] = None,\n",
+ "        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,\n",
+ "        labels: Optional[torch.LongTensor] = None,\n",
+ "        evidence_labels: Optional[torch.LongTensor] = None,\n",
+ "        use_cache: Optional[bool] = None,\n",
+ "        output_attentions: Optional[bool] = None,\n",
+ "        output_hidden_states: Optional[bool] = None,\n",
+ "        return_dict: Optional[bool] = None,\n",
+ "    ):\n",
+ "        r\"\"\"\n",
+ "        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n",
+ "            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n",
+ "            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n",
+ "        Returns:\n",
+ "        \"\"\"\n",
+ "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
+ "        if labels is not None:\n",
+ "            use_cache = False\n",
+ "\n",
+ "        if input_ids is None and inputs_embeds is not None:\n",
+ "            raise NotImplementedError(\n",
+ "                f\"Passing input embeddings is currently not supported for {self.__class__.__name__}\"\n",
+ "            )\n",
+ "\n",
+ "        # Copied from models.bart.modeling_bart.BartModel.forward; different to other models, T5 automatically creates\n",
+ "        # decoder_input_ids from input_ids if no decoder_input_ids are provided\n",
+ "        if decoder_input_ids is None and decoder_inputs_embeds is None:\n",
+ "            if input_ids is None:\n",
+ "                raise ValueError(\n",
+ "                    \"If no `decoder_input_ids` or `decoder_inputs_embeds` are \"\n",
+ "                    \"passed, `input_ids` cannot be `None`. Please pass either \"\n",
+ "                    \"`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`.\"\n",
+ "                )\n",
+ "            decoder_input_ids = self._shift_right(input_ids)\n",
+ "\n",
+ "        outputs = self.transformer(\n",
+ "            input_ids,\n",
+ "            attention_mask=attention_mask,\n",
+ "            decoder_input_ids=decoder_input_ids,\n",
+ "            decoder_attention_mask=decoder_attention_mask,\n",
+ "            head_mask=head_mask,\n",
+ "            decoder_head_mask=decoder_head_mask,\n",
+ "            cross_attn_head_mask=cross_attn_head_mask,\n",
+ "            encoder_outputs=encoder_outputs,\n",
+ "            inputs_embeds=inputs_embeds,\n",
+ "            decoder_inputs_embeds=decoder_inputs_embeds,\n",
+ "            use_cache=use_cache,\n",
+ "            output_attentions=output_attentions,\n",
+ "            output_hidden_states=output_hidden_states,\n",
+ "            return_dict=return_dict,\n",
+ "        )\n",
+ "        sequence_output = outputs[0]  # (bsz, max_length, hidden_size)\n",
+ "\n",
+ "        token_logits = self.evidence_head(sequence_output)  # (bsz, max_length, 2)\n",
+ "        token_loss = None\n",
+ "        if evidence_labels is not None:\n",
+ "            evidence_labels = evidence_labels.to(token_logits.device)\n",
+ "            loss_fct = nn.CrossEntropyLoss()\n",
+ "            token_loss = loss_fct(token_logits.view(-1, self.num_labels), evidence_labels.view(-1))\n",
+ "\n",
+ "        eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)\n",
+ "\n",
+ "        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:\n",
+ "            raise ValueError(\"All examples must have the same number of <eos> tokens.\")\n",
+ "        batch_size, _, hidden_size = sequence_output.shape\n",
+ "        sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]  # (bsz, hidden_size)\n",
+ "        sent_logits = self.verdict_head(sentence_representation)\n",
+ "\n",
+ "        sent_loss = None\n",
+ "        if labels is not None:\n",
+ "            labels = labels.to(sent_logits.device)\n",
+ "            if self.config.problem_type is None:\n",
+ "                if self.config.num_labels == 1:\n",
+ "                    self.config.problem_type = \"regression\"\n",
+ "                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n",
+ "                    self.config.problem_type = \"single_label_classification\"\n",
+ "                else:\n",
+ "                    self.config.problem_type = \"multi_label_classification\"\n",
+ "\n",
+ "            if self.config.problem_type == \"regression\":\n",
+ "                loss_fct = nn.MSELoss()\n",
+ "                if self.config.num_labels == 1:\n",
+ "                    sent_loss = loss_fct(sent_logits.squeeze(), labels.squeeze())\n",
+ "                else:\n",
+ "                    sent_loss = loss_fct(sent_logits, labels)\n",
+ "            elif self.config.problem_type == \"single_label_classification\":\n",
+ "                loss_fct = nn.CrossEntropyLoss()\n",
+ "                sent_loss = loss_fct(sent_logits.view(-1, self.num_verdicts), labels.view(-1))\n",
+ "            elif self.config.problem_type == \"multi_label_classification\":\n",
+ "                loss_fct = nn.BCEWithLogitsLoss()\n",
+ "                sent_loss = loss_fct(sent_logits, labels)\n",
+ "\n",
+ "        total_loss = None\n",
+ "        if sent_loss is not None and token_loss is not None:\n",
+ "            total_loss = 0.7*sent_loss + 0.3*token_loss\n",
+ "\n",
+ "        return TokenClassificationOutput(\n",
+ "            loss=total_loss,\n",
+ "            token_loss=token_loss,\n",
+ "            sent_loss=sent_loss,\n",
+ "            claim_logits=sent_logits,\n",
+ "            evidence_logits=token_logits\n",
+ "        )\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def train(model, train_dataloader, val_dataloader, args):\n",
+ "    print(f\"Mem needed: {model.get_memory_footprint() / 1024 / 1024 / 1024:.2f} GB\")\n",
+ "\n",
+ "    # creating a tmp directory to save the models\n",
+ "    out_dir = os.path.abspath(os.path.join(os.path.curdir, \"tmp-runs\", datetime.today().strftime('%a-%d-%b-%Y-%I:%M:%S%p')))\n",
+ "\n",
+ "    # hparams\n",
+ "    min_loss = float('inf')\n",
+ "    sub_cycle = 0\n",
+ "    best_path = None\n",
+ "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "    if args.freeze_backbone:\n",
+ "        model.freeze_backbone()\n",
+ "\n",
+ "    if args.freeze_encoder:\n",
+ "        model.freeze_encoder()\n",
+ "\n",
+ "    if args.freeze_decoder:\n",
+ "        model.freeze_decoder()\n",
+ "\n",
+ "    if args.gradient_checkpointing:\n",
+ "        model.gradient_checkpointing_enable()\n",
+ "\n",
+ "    total_num_steps = (len(train_dataloader) / args.gradient_accumulation_steps) * args.num_epochs\n",
+ "    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)\n",
+ "\n",
+ "    sched = get_scheduler(\n",
+ "        name=args.scheduler_name,\n",
+ "        optimizer=opt,\n",
+ "        num_warmup_steps=args.warmup_steps,\n",
+ "        num_training_steps=total_num_steps,\n",
+ "    )\n",
+ "\n",
+ "    model.to(device)\n",
+ "\n",
+ "    print(\"Start Training\")\n",
+ "\n",
+ "    for ep in range(args.num_epochs):\n",
+ "        model.train()\n",
+ "        train_loss = 0.0\n",
+ "        train_acc = {'qa': 0.0, 'cls': 0.0}\n",
+ "\n",
+ "        for step, batch in enumerate(pbar := tqdm(train_dataloader, desc=f\"Epoch {ep} - training\")):\n",
+ "            # transfer data to training device (gpu/cpu)\n",
+ "            batch = batch_to_device(batch, device)\n",
+ "\n",
+ "            # forward\n",
+ "            outputs = model(**batch)\n",
+ "\n",
+ "            # compute loss\n",
+ "            loss = outputs.loss\n",
+ "\n",
+ "            # gather metrics\n",
+ "            train_loss += loss.item()\n",
+ "\n",
+ "            # progress bar logging\n",
+ "            pbar.set_postfix(loss=loss.item(), sent_loss=outputs.sent_loss.item(), token_loss=outputs.token_loss.item())\n",
+ "\n",
+ "            # backward and optimize\n",
+ "            loss.backward()\n",
+ "\n",
+ "            if (step + 1) % args.gradient_accumulation_steps == 0 or (step+1) == len(train_dataloader):\n",
+ "                opt.step()\n",
+ "                sched.step()\n",
+ "                opt.zero_grad()\n",
+ "\n",
+ "        train_loss /= len(train_dataloader)\n",
+ "\n",
+ "        # Evaluate at the end of the epoch\n",
+ "        model.eval()\n",
+ "        val_loss = 0.0\n",
+ "\n",
+ "        for batch in (pbar := tqdm(val_dataloader, desc=f\"Epoch {ep} - validation\")):\n",
+ "            with torch.no_grad():\n",
+ "                batch = batch_to_device(batch, device)\n",
+ "                # forward\n",
+ "                outputs = model(**batch)\n",
+ "\n",
+ "                # compute loss\n",
+ "                loss = outputs.loss\n",
+ "\n",
+ "                # gather metrics\n",
+ "                val_loss += loss.item()\n",
+ "\n",
+ "                pbar.set_postfix(loss=loss.item(), sent_loss=outputs.sent_loss.item(), token_loss=outputs.token_loss.item())\n",
+ "\n",
+ "        val_loss /= len(val_dataloader)\n",
+ "\n",
+ "        print(f\"Summary epoch {ep}:\\n\"\n",
+ "              f\"\\ttrain_loss: {train_loss:.4f} \\t val_loss: {val_loss:.4f}\")\n",
+ "\n",
+ "        if val_loss < min_loss:\n",
+ "            min_loss = val_loss\n",
+ "            sub_cycle = 0\n",
+ "\n",
+ "            best_path = os.path.join(out_dir, f\"epoch_{ep}\")\n",
+ "            print(f\"Save cur model to {best_path}\")\n",
+ "\n",
+ "            try:\n",
+ "                model.push_to_hub('hduc-le/VyT5-Siamese-Fact-Check', private=True)\n",
+ "            except:\n",
+ "                print(\"Failed to push model to hub\")\n",
+ "\n",
+ "            model.save_pretrained(best_path)\n",
+ "\n",
+ "        else:\n",
+ "            sub_cycle += 1\n",
+ "            if sub_cycle == args.patience:\n",
+ "                print(\"Early stopping!\")\n",
+ "                break\n",
+ "\n",
+ "    print(\"End of training. Restore the best weights\")\n",
+ "    best_model = ViT5ForTokenClassification.from_pretrained(best_path)\n",
+ "\n",
+ "    if args.save_best:\n",
+ "        # save the best model to a permanent directory\n",
+ "        out_dir = os.path.abspath(os.path.join(os.path.curdir, \"saved-runs\", datetime.today().strftime('%a-%d-%b-%Y-%I:%M:%S%p')))\n",
+ "\n",
+ "        best_path = os.path.join(out_dir, 'best')\n",
+ "        try:\n",
+ "            model.push_to_hub('hduc-le/VyT5-SentToken-Classification', private=True)\n",
+ "        except:\n",
+ "            print(\"Failed to push model to hub\")\n",
+ "\n",
+ "        print(f\"Save best model to {best_path}\")\n",
+ "\n",
+ "        best_model.save_pretrained(best_path)\n",
+ "\n",
+ "    return\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = ViT5ForTokenClassification.from_pretrained(\n",
+ "    training_args.model_name, use_cache=False, output_hidden_states=True\n",
+ ")\n",
+ "print(model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train(model, train_dataloader, val_dataloader, args=training_args)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "mlds",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
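
For reference, a minimal inference sketch built on what this notebook defines, not part of the committed file itself: it assumes the ViT5ForTokenClassification class and AutoTokenizer from the cells above are in scope, and the checkpoint path below is a placeholder for a directory produced by train(). The verdict id is mapped back through the inverse of the notebook's _LABEL_MAPPING, and the evidence string is decoded from the tokens the token-level head assigns to class 1.

import torch

# Inverse of the notebook's _LABEL_MAPPING = {"SUPPORTED": 0, "NEI": 1, "REFUTED": 2}
_ID2VERDICT = {0: "SUPPORTED", 1: "NEI", 2: "REFUTED"}

@torch.no_grad()
def predict(model, tokenizer, context: str, claim: str, max_seq_length: int = 1024, device: str = "cpu"):
    model.to(device).eval()
    # Same (context, claim) encoding as TokenStanceDataset.__getitem__
    enc = tokenizer(
        context,
        claim,
        truncation=True,
        padding="max_length",
        max_length=max_seq_length,
        return_tensors="pt",
    ).to(device)
    out = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask)

    # Verdict: argmax over the 3-way sentence head (claim_logits)
    verdict = _ID2VERDICT[out.claim_logits.argmax(dim=-1).item()]

    # Evidence: decode the tokens the binary token head (evidence_logits) labels as 1
    token_mask = out.evidence_logits.argmax(dim=-1).squeeze(0).bool()
    evidence = tokenizer.decode(enc.input_ids.squeeze(0)[token_mask], skip_special_tokens=True)
    return {"verdict": verdict, "evidence": evidence}

# Example usage (checkpoint path is a placeholder):
# model = ViT5ForTokenClassification.from_pretrained("saved-runs/<run>/best")
# tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base", use_fast=True)
# print(predict(model, tokenizer, context="...", claim="...", device="cuda"))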