Truong-Phuc Nguyen committed on
Commit
e7fcbb8
1 Parent(s): a4dc1dd

Update plms/language_model.py

Browse files
Files changed (1) hide show
  1. plms/language_model.py +761 -613
plms/language_model.py CHANGED
@@ -1,613 +1,761 @@
1
- import os
2
- import logging
3
- import pickle
4
- import re
5
- import urllib
6
- from itertools import chain
7
- from typing import List, Dict
8
- from multiprocessing import Pool
9
- import numpy as np
10
- from tqdm import tqdm
11
- import torch
12
- from torch.nn import functional
13
- import transformers
14
- from .exceptions import ExceedMaxLengthError, HighlightNotFoundError, AnswerNotFoundError
15
- from .spacy_module import SpacyPipeline, VALID_METHODS
16
-
17
- __all__ = ('TransformersQG', 'ADDITIONAL_SP_TOKENS', 'TASK_PREFIX', 'clean', 'internet_connection')
18
-
19
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # to turn off warning message
# Task-specific prefixes prepended to the model input by `EncodePlus`, keyed by task alias.
TASK_PREFIX = {
    "ae": "extract answers",
    "qg": "generate question",
    "qag": "generate question and answer",
    "qa": "answer question"
}
# Label id that `label_smoothed_loss` treats as padding and excludes from the loss.
CE_IGNORE_INDEX = -100
# Extra special tokens registered on the tokenizer; `<hl>` surrounds highlighted spans.
ADDITIONAL_SP_TOKENS = {'hl': '<hl>'}
# Both settings are read once from the environment at import time.
NUM_WORKERS = int(os.getenv('NUM_WORKERS', '0'))
PARALLEL_PROCESSING = bool(int(os.getenv('PARALLEL_PROCESSING', '0')))
# Model used per language alias when the caller does not name a model.
DEFAULT_MODELS = {
    'vi': 'VietAI/vit5-base'
}
33
-
34
def pickle_save(obj, path: str):
    """ Serialize `obj` with pickle and write it to `path`.

    @param obj: Object to serialize; must be picklable.
    @param path: Destination file, opened in binary write mode (an existing file is overwritten).
    """
    with open(path, "wb") as fp:
        pickle.dump(obj, fp)
37
-
38
-
39
def pickle_load(path: str):
    """ Read the file at `path` and return the object unpickled from it.

    NOTE: unpickling can execute arbitrary code, so only use this on files this
    program wrote itself (e.g. via `pickle_save`), never on untrusted input.

    @param path: File to read, opened in binary mode.
    @return: The deserialized object.
    """
    with open(path, "rb") as fp:  # Unpickling
        return pickle.load(fp)
42
-
43
-
44
def clean(string):
    """ Strip leading/trailing whitespace and map blank results to None.

    @param string: Text to normalise; None is passed through as None.
    @return: The stripped text, or None when nothing but whitespace was given.
    """
    if string is None:
        return None
    stripped = string.strip()
    return stripped if stripped else None
50
-
51
-
52
def internet_connection(host='http://google.com'):
    """ Best-effort probe: report whether `host` can be opened.

    @param host: URL to open.
    @return: True if the URL opened without raising, False on any error.
    """
    # `import urllib` alone does not load the `request` submodule, so relying on it
    # makes the probe fail with AttributeError unless another module imported it first.
    import urllib.request
    try:
        # the context manager closes the response instead of leaking the socket
        with urllib.request.urlopen(host):
            return True
    except Exception:
        # any failure (bad URL, DNS, refused, HTTP error, ...) counts as "offline";
        # KeyboardInterrupt / SystemExit are deliberately not swallowed
        return False
58
-
59
-
60
def load_language_model(model_name,
                        cache_dir: str = None,
                        use_auth_token: bool = False,
                        torch_dtype=None,
                        device_map: str = None,
                        low_cpu_mem_usage: bool = False):
    """ Load tokenizer, seq2seq model and config from the huggingface model hub (or the local cache).

    @param model_name: Model alias or path handed to `from_pretrained`.
    @param cache_dir: Directory to cache transformers model files.
    @param use_auth_token: Forwarded to every `from_pretrained` call.
    @param torch_dtype: Forwarded to the model's `from_pretrained` only when not None.
    @param device_map: Forwarded to the model's `from_pretrained` only when not None.
    @param low_cpu_mem_usage: Forwarded to the model's `from_pretrained`.
    @return: Tuple of (tokenizer, model, config).
    @raise ValueError: If `config.model_type` is not one of the supported seq2seq types.
    """
    # fall back to locally cached files when the connectivity probe fails
    local_files_only = not internet_connection()
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name, cache_dir=cache_dir, local_files_only=local_files_only, use_auth_token=use_auth_token)
    config = transformers.AutoConfig.from_pretrained(
        model_name, local_files_only=local_files_only, cache_dir=cache_dir, use_auth_token=use_auth_token)

    # model type -> name of the conditional-generation class; resolved by name so that a
    # class missing from the installed transformers version only matters when it is requested
    model_class_names = {
        't5': 'T5ForConditionalGeneration',
        'mt5': 'MT5ForConditionalGeneration',
        'bart': 'BartForConditionalGeneration',
        'mbart': 'MBartForConditionalGeneration',
        'switch_transformers': 'SwitchTransformersForConditionalGeneration',
    }
    if config.model_type not in model_class_names:
        raise ValueError(f'unsupported model type: {config.model_type}')
    model_class = getattr(transformers, model_class_names[config.model_type]).from_pretrained

    param = {'config': config, "local_files_only": local_files_only, "use_auth_token": use_auth_token,
             "low_cpu_mem_usage": low_cpu_mem_usage, "cache_dir": cache_dir}
    if torch_dtype is not None:
        param['torch_dtype'] = torch_dtype
    if device_map is not None:
        param['device_map'] = device_map
    model = model_class(model_name, **param)
    # add new special tokens to the tokenizer and the model if they don't have it
    tokenizer.add_special_tokens({'additional_special_tokens': list(ADDITIONAL_SP_TOKENS.values())})
    model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model, config
98
-
99
-
100
def label_smoothed_loss(logits, labels, epsilon):
    """ Label-smoothed negative log-likelihood that ignores positions labelled `CE_IGNORE_INDEX`.

    Adapted from
    https://github.com/huggingface/transformers/blob/55bb4c06f7be141c6d895dbe1f11018dc8580b2d/src/transformers/trainer_pt_utils.py#L430

    @param logits: Unnormalised scores; the softmax is taken over the last dimension.
    @param labels: Target indices; may lack the trailing dimension of `logits`. Left unmodified.
    @param epsilon: Smoothing weight; the result is (1 - epsilon) * nll + epsilon * smoothed.
    @return: Scalar loss tensor averaged over the non-ignored positions.
    """
    log_probs = - functional.log_softmax(logits, dim=-1)
    if labels.dim() == log_probs.dim() - 1:
        labels = labels.unsqueeze(-1)

    padding_mask = labels.eq(CE_IGNORE_INDEX)
    # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
    # will ignore them in any case. The clamp is out-of-place on purpose: `unsqueeze` returns a view,
    # so an in-place clamp would overwrite the ignore markers in the caller's tensor.
    labels = labels.clamp(min=0)

    nll_loss = log_probs.gather(dim=-1, index=labels)
    nll_loss.masked_fill_(padding_mask, 0.0)

    # works for fp16 input tensor too, by internally upcasting it to fp32
    smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
    smoothed_loss.masked_fill_(padding_mask, 0.0)

    # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
    num_active_elements = padding_mask.numel() - padding_mask.long().sum()
    nll_loss = nll_loss.sum() / num_active_elements
    smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
    return (1 - epsilon) * nll_loss + epsilon * smoothed_loss
123
-
124
-
125
class Dataset(torch.utils.data.Dataset):
    """ torch.utils.data.Dataset wrapper converting into tensor """
    # feature names converted to float32; every other feature becomes int64
    float_tensors = ['attention_mask']

    def __init__(self, data: List):
        """ @param data: Sequence of dicts mapping a feature name to its (nested) list of numbers. """
        self.data = data

    def __len__(self):
        return len(self.data)

    def to_tensor(self, name, data):
        """ Convert one feature to a tensor, choosing the dtype from the feature name. """
        if name in self.float_tensors:
            return torch.tensor(data, dtype=torch.float32)
        return torch.tensor(data, dtype=torch.long)

    def __getitem__(self, idx):
        """ Return the `idx`-th example with every feature converted to a tensor. """
        return {k: self.to_tensor(k, v) for k, v in self.data[idx].items()}
142
-
143
-
144
class EncodePlus:
    """ Wrapper of encode_plus for multiprocessing. """

    def __init__(self,
                 tokenizer,
                 max_length: int = 512,
                 max_length_output: int = 34,
                 drop_overflow_error_text: bool = False,
                 skip_overflow_error: bool = False,
                 drop_highlight_error_text: bool = False,
                 prefix_type: str = None,
                 padding: bool = True):
        """ Wrapper of encode_plus for multiprocessing.

        @param tokenizer: transformers tokenizer.
        @param max_length: Max text length of input.
        @param max_length_output: Max text length of output.
        @param drop_overflow_error_text: If true, return None when the input/output exceeds the max length.
        @param skip_overflow_error: If true, do not check the length at all (the text is truncated by the
            tokenizer); if false and `drop_overflow_error_text` is false, an overflow raises an error.
        @param drop_highlight_error_text: If true, return None (instead of raising an error) when a highlight
            span is not found in the paragraph.
        @param prefix_type: A key of `TASK_PREFIX` (`ae`, `qg`, `qag` or `qa`) whose prefix is added at the
            beginning of the text, or None for no prefix.
        @param padding: Pad the sequence to the max length.
        """
        self.prefix = TASK_PREFIX[prefix_type] if prefix_type is not None else None
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_length_output = max_length_output
        # NOTE: for model training, we should drop the exceeded input but not for the evaluator
        self.drop_overflow_error_text = drop_overflow_error_text
        self.skip_overflow_error = skip_overflow_error
        self.drop_highlight_error_text = drop_highlight_error_text
        # truncation should be true for the batch process, but not necessary to process single input
        self.param_in = {'truncation': True, 'max_length': self.max_length}
        self.param_out = {'truncation': True, 'max_length': self.max_length_output}
        if padding:
            self.param_in['padding'] = 'max_length'
            self.param_out['padding'] = 'max_length'

    def __call__(self, inputs):
        """ Unpack `inputs` as the positional arguments of `encode_plus` (convenient for `Pool.map`). """
        return self.encode_plus(*inputs)

    def encode_plus(self, input_sequence: str, output_sequence: str = None, input_highlight: str = None):
        """ encode_plus

        @param input_sequence: Input sequence.
        @param output_sequence: Output sequence.
        @param input_highlight: Sub-sequence of `input_sequence` to be surrounded by <hl>.
        @return: The output of the tokenizer (with `labels` added when `output_sequence` is given), or None
            when the example is dropped because of a missing highlight or an overflow.
        @raise HighlightNotFoundError: Highlight is absent and `drop_highlight_error_text` is false.
        @raise ExceedMaxLengthError: Input/output is too long and neither overflow flag is set.
        """
        # add highlight to the input (only the first occurrence is marked)
        if input_highlight is not None:
            position = input_sequence.find(input_highlight)
            if position == -1:
                if self.drop_highlight_error_text:
                    return None
                raise HighlightNotFoundError(input_highlight, input_sequence)
            input_sequence = '{0}{1} {2} {1}{3}'.format(
                input_sequence[:position], ADDITIONAL_SP_TOKENS['hl'], input_highlight,
                input_sequence[position+len(input_highlight):])
        if self.prefix is not None:
            input_sequence = f'{self.prefix}: {input_sequence}'

        # handling overflow text
        # drop_overflow_error_text ==> remove the overflow sentence from input
        # skip_overflow_error ==> keep the overflow sentence
        # none of them ==> raise error
        if self.drop_overflow_error_text or not self.skip_overflow_error:
            if len(self.tokenizer.encode(input_sequence)) > self.max_length:
                if not self.drop_overflow_error_text:  # raise error for overflow text
                    raise ExceedMaxLengthError(self.max_length)
                return None  # remove overflow text
            if output_sequence is not None:
                if len(self.tokenizer.encode(output_sequence)) > self.max_length_output:
                    if not self.drop_overflow_error_text:  # raise error for overflow text
                        # report the limit that was actually exceeded (the output one)
                        raise ExceedMaxLengthError(self.max_length_output)
                    return None  # remove overflow text
        if type(self.tokenizer) is transformers.models.mbart.tokenization_mbart_fast.MBartTokenizerFast:
            encode = self.tokenizer(input_sequence, **self.param_in)
        else:
            encode = self.tokenizer(text_target=input_sequence, **self.param_in)
        if output_sequence is not None:
            encode['labels'] = self.tokenizer.encode(output_sequence, **self.param_out)
        return encode
227
-
228
-
229
class TransformersQG:
    """ Transformers Language Model for Question Generation.

    NOTE(review): `generate_prediction` calls `self.eval`, `self.text_to_encode` and
    `self.get_data_loader`, none of which is defined in this part of the class.
    """

    def __init__(self,
                 model: str = None,
                 max_length: int = 512,
                 max_length_output: int = 256,
                 model_ae: str = None,
                 max_length_ae: int = 512,
                 max_length_output_ae: int = 64,
                 cache_dir: str = None,
                 add_prefix: bool = None,
                 language: str = 'vi',
                 label_smoothing: float = None,
                 skip_overflow_error: bool = False,
                 drop_overflow_error_text: bool = False,
                 drop_highlight_error_text: bool = False,
                 drop_answer_error_text: bool = False,
                 use_auth_token: bool = False,
                 torch_dtype=None,
                 device_map: str = None,
                 low_cpu_mem_usage: bool = False,
                 is_qg: bool = None,
                 is_qag: bool = None,
                 is_qa: bool = None,
                 is_ae: bool = None):
        """ Transformers Language Model for Question Generation.

        @param model: Model alias or path to local model file (default: `DEFAULT_MODELS[language]`).
        @param max_length: Max text length of input.
        @param max_length_output: Max text length of output.
        @param model_ae: Answer extraction model: a spaCy keyword method listed in `VALID_METHODS`, or a
            model alias/path. Defaults to `model` when it is an AE model, else to "positionrank".
        @param max_length_ae: Max text length of input for the answer extraction model.
        @param max_length_output_ae: Max text length of output for the answer extraction model.
        @param cache_dir: Directory to cache transformers model files.
        @param add_prefix: Whether model uses task-specific prefix (eg. True for T5 but False for BART models).
            Only used when the model config has no `add_prefix` entry.
        @param language: Language alias for SpaCy language-specific pipelines (sentencizer/keyword extraction).
        @param label_smoothing: [Fine-tuning parameter] Label smoothing.
        @param drop_overflow_error_text: If true, return None when the input exceeds the max length.
        @param skip_overflow_error: If true, do not raise an error when the input exceeds the max length.
        @param drop_highlight_error_text: If true, drop the example instead of raising an error when a
            highlight span is not found in the paragraph.
        @param drop_answer_error_text: If true, `generate_a` returns None for a context without any answer
            instead of raising `AnswerNotFoundError`.
        @param use_auth_token: [optional] Huggingface transformers argument of `use_auth_token`
        @param torch_dtype: Forwarded to `load_language_model` for the main model.
        @param device_map: Forwarded to `load_language_model` for the main model.
        @param low_cpu_mem_usage: Forwarded to `load_language_model` for the main model.
        @param is_qg: Model supports question generation (default: 'qg' is a '-'-separated part of `model`).
        @param is_qag: Same for end-to-end question/answer pair generation ('qag').
        @param is_qa: Same for question answering ('qa').
        @param is_ae: Same for answer extraction ('ae').
        """

        # take default model given the language
        if model is None:
            assert language in DEFAULT_MODELS.keys(),\
                f"Model with language '{language}' is not available. Please choose language from " \
                f"'{DEFAULT_MODELS.keys()}' or specify 'model'."
            model = DEFAULT_MODELS[language]

        # classify model type
        self.is_qg = 'qg' in model.split('-') if is_qg is None else is_qg
        self.is_ae = 'ae' in model.split('-') if is_ae is None else is_ae
        self.is_qa = 'qa' in model.split('-') if is_qa is None else is_qa
        self.is_qag = 'qag' in model.split('-') if is_qag is None else is_qag
        # configs
        self.model_name = model
        self.max_length = max_length
        self.max_length_output = max_length_output
        self.label_smoothing = label_smoothing
        self.drop_overflow_error_text = drop_overflow_error_text
        self.skip_overflow_error = skip_overflow_error
        self.drop_highlight_error_text = drop_highlight_error_text
        self.drop_answer_error_text = drop_answer_error_text
        self.model_name_ae = model_ae
        self.max_length_ae = max_length_ae
        self.max_length_output_ae = max_length_output_ae
        # load model
        self.tokenizer, self.model, config = load_language_model(
            self.model_name, cache_dir=cache_dir, use_auth_token=use_auth_token, device_map=device_map,
            torch_dtype=torch_dtype, low_cpu_mem_usage=low_cpu_mem_usage)
        if 'add_prefix' not in config.to_dict().keys():
            # this means the model is not fine-tuned
            # assert add_prefix, '`add_prefix` is required for non-fine-tuned models'
            self.add_prefix = add_prefix
        else:
            self.add_prefix = config.add_prefix

        # set default behaviour for answer extraction
        if self.model_name_ae is None:
            self.model_name_ae = self.model_name if self.is_ae else "positionrank"
        # load answer extraction model
        self.answer_model_type = None
        if self.model_name_ae in VALID_METHODS:
            logging.info(f'use spaCy answer extraction model: {self.model_name_ae}')
            self.tokenizer_ae = self.model_ae = self.add_prefix_ae = None
            self.spacy_module = SpacyPipeline(language, self.model_name_ae)
            self.answer_model_type = 'spacy'
        else:
            logging.info(f'use LMQG fine-tuned answer extraction model: {self.model_name_ae}')
            if self.model_name == self.model_name_ae:
                logging.info("the same model as QG is used as AE")
                assert self.is_ae, f"the model ({self.model_name_ae}) is not fine-tuned for AE"
                self.tokenizer_ae = self.model_ae = self.add_prefix_ae = None
                self.answer_model_type = 'multitask'
            else:
                logging.info(f"loading 2nd model for AE: {self.model_name_ae}")
                self.tokenizer_ae, self.model_ae, config_ae = load_language_model(
                    self.model_name_ae, cache_dir=cache_dir, use_auth_token=use_auth_token)
                # same fallback as the main model: a config without `add_prefix` uses the argument
                # instead of failing with AttributeError
                self.add_prefix_ae = getattr(config_ae, 'add_prefix', add_prefix)
                self.answer_model_type = 'pipeline'
            # the sentence splitter is needed by both LM-based answer extraction modes
            self.spacy_module = SpacyPipeline(language)

        # GPU setup
        self.device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
        self.parallel = False
        if torch.cuda.device_count() > 1:
            self.parallel = True
            self.model = torch.nn.DataParallel(self.model)
            if self.model_ae is not None:
                self.model_ae = torch.nn.DataParallel(self.model_ae)
        self.model.to(self.device)
        if self.model_ae is not None:
            self.model_ae.to(self.device)
        logging.info(f'Model `{self.model_name}`')
        logging.info(f'\t * Num of GPU in use: {torch.cuda.device_count()}')
        logging.info(f'\t * Prefix: {self.add_prefix}')
        logging.info(f'\t * Language: {language} (ignore at the training phase)')

    def generate_qa_end2end(self,
                            list_context: str or List,
                            batch_size: int = None,
                            num_beams: int = 4,
                            cache_path: str = None,
                            splitting_symbol: str = ' [SEP] ',
                            question_prefix: str = "question: ",
                            answer_prefix: str = ", answer: "):
        """ Generate question/answer pairs directly from each context with an end-to-end QAG model.

        @param list_context: Input text, or list of input texts.
        @param batch_size: Batch size.
        @param num_beams: Number of beam for model generation.
        @param cache_path: Path to pre-compute features.
        @param splitting_symbol: Separator between the pairs inside one generated sequence.
        @param question_prefix: Marker preceding the question inside a generated pair.
        @param answer_prefix: Marker separating question and answer inside a generated pair.
        @return: List of (question, answer) tuples for a single string input, otherwise one such list
            per context. Pairs that do not follow the expected format are logged and skipped.
        """
        logging.info('running model for `question_answer_pair_generation`')
        assert self.is_qag, "`generate_qa_end2end` is available for end2end_qag_model"
        prefix_type = 'qag' if self.add_prefix else None
        single_input = isinstance(list_context, str)
        list_context = [list_context] if single_input else list_context
        output = self.generate_prediction(
            list_context, prefix_type=prefix_type, cache_path=cache_path, num_beams=num_beams, batch_size=batch_size
        )

        def format_qa(list_raw_string):
            # parse each "<question_prefix>Q<answer_prefix>A" chunk into a (Q, A) tuple
            tmp = []
            for raw_string in list_raw_string:
                if len(raw_string.split(answer_prefix)) != 2 or question_prefix not in raw_string:
                    logging.info(f"invalid prediction: {raw_string}")
                else:
                    q, a = raw_string.split(answer_prefix)
                    a = re.sub(r'\A\s+', '', a)
                    a = re.sub(r'\s+\Z', '', a)
                    q = q.replace(question_prefix, "")
                    q = re.sub(r'\A\s+', '', q)
                    q = re.sub(r'\s+\Z', '', q)
                    tmp.append((q, a))
            return tmp

        output = [format_qa(o.split(splitting_symbol)) for o in output]
        return output[0] if single_input else output

    def generate_qa(self,
                    list_context: str or List,
                    batch_size: int = None,
                    num_beams: int = 4,
                    cache_path: str = None,
                    num_questions: int = None,
                    sentence_level: bool = False):
        """ Generate question/answer pairs given context (answer extraction followed by question generation).

        @param list_context: Input text, or list of input texts.
        @param batch_size: Batch size.
        @param num_beams: Number of beam for model generation.
        @param cache_path: Path to pre-compute features.
        @param num_questions: Max number of questions.
        @param sentence_level: Run prediction on each sentence of the context independently to reduce complexity.
        @return: List of (question, answer) tuples for a single string input, otherwise one entry per
            context; the entry is None for a context for which `generate_a` returned None.
        """
        if self.is_qag:
            return self.generate_qa_end2end(list_context, batch_size, num_beams, cache_path)
        single_input = isinstance(list_context, str)
        list_context = [list_context] if single_input else list_context
        original_input_length = len(list_context)

        logging.info('running model for `ae`')
        list_answer = self.generate_a(
            list_context,
            batch_size=batch_size,
            num_beams=num_beams,
            cache_path=cache_path,
            sentence_level=sentence_level,
            num_questions=num_questions
        )
        valid_context_id = [n for n, a in enumerate(list_answer) if a is not None]
        list_context = [list_context[n] for n in valid_context_id]
        list_answer = [list_answer[n] for n in valid_context_id]
        # flatten to one (context, answer) pair per question; `list_length` keeps the cumulative
        # offsets needed to regroup the flat predictions per context afterwards
        qg_input, qg_hl, list_length = [], [], [0]
        for c, a in zip(list_context, list_answer):
            qg_hl += a
            qg_input += [c] * len(a)
            list_length.append(list_length[-1] + len(a))
        logging.info('running model for `qg`')
        if qg_input:
            list_question = self.generate_q(
                qg_input,
                list_answer=qg_hl,
                batch_size=batch_size,
                cache_path=cache_path,
                num_beams=num_beams,
                sentence_level=sentence_level
            )
        else:
            # no answer to ask about: skip the model call altogether
            list_question = []

        assert len(qg_hl) == len(list_question), f"{len(qg_hl)} != {len(list_question)}"

        # return to nested list
        list_question = [list_question[list_length[n - 1]:list_length[n]] for n in range(1, len(list_length))]
        list_answer = [qg_hl[list_length[n - 1]:list_length[n]] for n in range(1, len(list_length))]
        output_list = [None] * original_input_length

        for n, _id in enumerate(valid_context_id):
            output_list[_id] = [(q, a) for q, a in zip(list_question[n], list_answer[n])]
        return output_list[0] if single_input else output_list

    def generate_a(self,
                   context: str or List,
                   batch_size: int = None,
                   num_beams: int = 4,
                   cache_path: str = None,
                   sentence_level: bool = False,
                   num_questions: int = None):
        """ Generate answers from each sentence.

        @param context: Input text, or list of input texts.
        @param batch_size: Batch size.
        @param num_beams: Number of beam for model generation.
        @param cache_path: Path to pre-compute features.
        @param sentence_level: Run prediction on each sentence of the context independently to reduce complexity.
        @param num_questions: Max number of questions (only used by the spaCy keyword extractor; default 10).
        @return: List of generated answers (one such list per context when a list is given).
        @raise AnswerNotFoundError: No answer was found for a context and `drop_answer_error_text` is false.
        """
        logging.info('running model for `answer_extraction`')
        if self.answer_model_type == 'spacy':
            num_questions = 10 if num_questions is None else num_questions
            if isinstance(context, str):
                return self.spacy_module.keyword(context, num_questions)
            else:
                return [self.spacy_module.keyword(c, num_questions) for c in context]
        single_input = isinstance(context, str)
        context = [context] if single_input else context
        list_sentences = [self.spacy_module.sentence(c) for c in context]  # split into sentence
        # by default every sentence is highlighted inside its full context
        list_inputs = [[c] * len(s) for c, s in zip(context, list_sentences)]
        list_length = [0] + np.cumsum([len(s) for s in list_sentences]).tolist()
        if sentence_level:
            list_inputs = list_sentences
        # flatten inputs
        flat_sentences = list(chain(*list_sentences))
        flat_inputs = list(chain(*list_inputs))
        if self.answer_model_type == 'multitask':
            answer = self.generate_prediction(
                flat_inputs,  # list_input,
                highlights=flat_sentences,  # highlights=list_sentence,
                prefix_type='ae' if self.add_prefix else None,
                cache_path=cache_path,
                num_beams=num_beams,
                batch_size=batch_size
            )
        elif self.answer_model_type == 'pipeline':
            answer = self.generate_prediction(
                flat_inputs,  # list_input,
                highlights=flat_sentences,  # highlights=list_sentence,
                prefix_type='ae' if self.add_prefix_ae else None,
                cache_path=cache_path,
                num_beams=num_beams,
                batch_size=batch_size,
                switch_to_model_ae=True
            )
        else:
            raise ValueError(f"unknown answer model type: {self.answer_model_type}")
        # return to nested list
        answer = [clean(a) for a in answer]
        list_answer = [answer[list_length[n - 1]:list_length[n]] for n in range(1, len(list_length))]
        # keep only non-empty answers that literally occur in the text they were extracted from
        list_answer = [[a for a, c in zip(a_sent, c_sent) if a is not None and a in c]
                       for a_sent, c_sent in zip(list_answer, list_inputs)]
        list_answer = [None if len(a) == 0 else a for a in list_answer]
        if not self.drop_answer_error_text:
            if any(a is None for a in list_answer):
                raise AnswerNotFoundError([context[n] for n, a in enumerate(list_answer) if a is None][0])
        return list_answer[0] if single_input else list_answer

    def generate_q(self,
                   list_context: str or List,
                   list_answer: List = None,
                   batch_size: int = None,
                   num_beams: int = 4,
                   cache_path: str = None,
                   sentence_level: bool = False):
        """ Generate question from paragraph and answer. Note that `list_answer` is needed unless they are already
        highlighted in the `list_context`. eg) "I live in <hl> Tokyo <hl>."

        @param list_context: Input text, or list of input texts.
        @param list_answer: Answer(s) in the `list_context` to be highlighted by <hl>; must be of the same
            type (str or list) as `list_context`.
        @param batch_size: Batch size.
        @param num_beams: Number of beam for model generation.
        @param cache_path: Path to pre-compute features.
        @param sentence_level: Run prediction on each sentence of the context independently to reduce complexity.
        @return: Generated question for a single string input, otherwise the list of generated questions.
        """
        assert self.is_qg, "model is not fine-tuned for QG"
        if list_answer is not None:
            assert type(list_context) is type(list_answer), f"{type(list_context)} != {type(list_answer)}"
        single_input = False
        if isinstance(list_context, str):
            list_context = [list_context]
            list_answer = [list_answer] if list_answer is not None else None
            single_input = True
        output = self.generate_prediction(
            list_context,
            highlights=list_answer,
            prefix_type='qg' if self.add_prefix else None,
            cache_path=cache_path,
            num_beams=num_beams,
            batch_size=batch_size,
            sentence_level=sentence_level
        )
        if single_input:
            return output[0]
        return output

    def generate_prediction(self,
                            inputs: List,
                            highlights: List or None = None,
                            prefix_type: str = None,
                            num_beams: int = 4,
                            batch_size: int = None,
                            cache_path: str = None,
                            sentence_level: bool = False,
                            switch_to_model_ae: bool = False):
        """ General method to generate model prediction

        @param inputs: List of input sequences.
        @param highlights: List of sub-sequences from list_context to be highlighted by <hl>.
        @param prefix_type: A key of `TASK_PREFIX` whose prefix is added at the beginning of the text.
        @param num_beams: Number of beam for model generation.
        @param batch_size: Batch size.
        @param cache_path: Path to pre-compute features.
        @param sentence_level: Replace each input by its first sentence containing the highlight (the input
            is kept as is when no sentence contains it). Requires `highlights`.
        @param switch_to_model_ae: Use the separate answer extraction model/tokenizer instead of the main one.
        @return: List of generated sequences.
        """
        self.eval()
        if switch_to_model_ae:
            assert self.model_ae is not None and self.tokenizer_ae is not None
            model = self.model_ae
            tokenizer = self.tokenizer_ae
            max_length_output = self.max_length_output_ae
        else:
            model = self.model
            tokenizer = self.tokenizer
            max_length_output = self.max_length_output

        if sentence_level:
            assert highlights is not None, '`sentence_level` needs `highlights`.'
            assert len(highlights) == len(inputs), str([len(highlights), len(inputs)])
            list_sentence = []
            for context, answer in zip(inputs, highlights):
                s = [sentence for sentence in self.spacy_module.sentence(context) if answer in sentence]
                list_sentence.append(s[0] if len(s) != 0 else context)
            inputs = list_sentence

        assert type(inputs) is list, inputs
        encode_list = self.text_to_encode(
            inputs,
            highlights=highlights,
            prefix_type=prefix_type,
            cache_path=cache_path,
            switch_to_model_ae=switch_to_model_ae
        )
        loader = self.get_data_loader(encode_list, batch_size=batch_size)
        outputs = []
        for encode in loader:
            with torch.no_grad():
                # labels are not an argument of `generate`
                if 'labels' in encode:
                    encode.pop('labels')
                encode = {k: v.to(self.device) for k, v in encode.items()}
                encode['max_length'] = max_length_output
                encode['num_beams'] = num_beams
                # DataParallel does not expose `generate`, so call the wrapped module
                tensor = model.module.generate(**encode) if self.parallel else model.generate(**encode)
                outputs += tokenizer.batch_decode(tensor, skip_special_tokens=True)
        return outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import pickle
4
+ import re
5
+ import urllib
6
+ from itertools import chain
7
+ from typing import List, Dict
8
+ from multiprocessing import Pool
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ import torch
12
+ from torch.nn import functional
13
+ import transformers
14
+ from .exceptions import ExceedMaxLengthError, HighlightNotFoundError, AnswerNotFoundError
15
+ from .spacy_module import SpacyPipeline, VALID_METHODS
16
+
17
+ __all__ = ('TransformersQG', 'ADDITIONAL_SP_TOKENS', 'TASK_PREFIX', 'clean', 'internet_connection')
18
+
19
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # to turn off warning message
# Task-specific prefixes prepended to the model input by `EncodePlus`, keyed by task alias.
TASK_PREFIX = {
    "ae": "extract answers",
    "qg": "generate question",
    "qag": "generate question and answer",
    "qa": "answer question"
}
# Label id that `label_smoothed_loss` treats as padding and excludes from the loss.
CE_IGNORE_INDEX = -100
# Extra special tokens registered on the tokenizer; `<hl>` surrounds highlighted spans.
ADDITIONAL_SP_TOKENS = {'hl': '<hl>'}
# Both settings are read once from the environment at import time.
NUM_WORKERS = int(os.getenv('NUM_WORKERS', '0'))
PARALLEL_PROCESSING = bool(int(os.getenv('PARALLEL_PROCESSING', '0')))
# Model used per language alias when the caller does not name a model.
DEFAULT_MODELS = {
    'vi': 'VietAI/vit5-base'
}
33
+
34
+
35
def pickle_save(obj, path: str):
    """ Serialize `obj` with pickle and write it to `path`.

    @param obj: Object to serialize; must be picklable.
    @param path: Destination file, opened in binary write mode (an existing file is overwritten).
    """
    with open(path, "wb") as fp:
        pickle.dump(obj, fp)
38
+
39
+
40
def pickle_load(path: str):
    """ Read the file at `path` and return the object unpickled from it.

    NOTE: unpickling can execute arbitrary code, so only use this on files this
    program wrote itself (e.g. via `pickle_save`), never on untrusted input.

    @param path: File to read, opened in binary mode.
    @return: The deserialized object.
    """
    with open(path, "rb") as fp:  # Unpickling
        return pickle.load(fp)
43
+
44
+
45
def clean(string):
    """ Strip leading/trailing whitespace and map blank results to None.

    @param string: Text to normalise; None is passed through as None.
    @return: The stripped text, or None when nothing but whitespace was given.
    """
    if string is None:
        return None
    stripped = string.strip()
    return stripped if stripped else None
51
+
52
+
53
def internet_connection(host='http://google.com'):
    """ Best-effort probe: report whether `host` can be opened.

    @param host: URL to open.
    @return: True if the URL opened without raising, False on any error.
    """
    # `import urllib` alone does not load the `request` submodule, so relying on it
    # makes the probe fail with AttributeError unless another module imported it first.
    import urllib.request
    try:
        # the context manager closes the response instead of leaking the socket
        with urllib.request.urlopen(host):
            return True
    except Exception:
        # any failure (bad URL, DNS, refused, HTTP error, ...) counts as "offline";
        # KeyboardInterrupt / SystemExit are deliberately not swallowed
        return False
59
+
60
+
61
def load_language_model(model_name,
                        cache_dir: str = None,
                        use_auth_token: bool = False,
                        torch_dtype=None,
                        device_map: str = None,
                        low_cpu_mem_usage: bool = False):
    """ Load tokenizer, seq2seq model and config from the huggingface model hub (or the local cache).

    @param model_name: Model alias or path handed to `from_pretrained`.
    @param cache_dir: Directory to cache transformers model files.
    @param use_auth_token: Forwarded to every `from_pretrained` call.
    @param torch_dtype: Forwarded to the model's `from_pretrained` only when not None.
    @param device_map: Forwarded to the model's `from_pretrained` only when not None.
    @param low_cpu_mem_usage: Forwarded to the model's `from_pretrained`.
    @return: Tuple of (tokenizer, model, config).
    @raise ValueError: If `config.model_type` is not one of the supported seq2seq types.
    """
    # fall back to locally cached files when the connectivity probe fails
    local_files_only = not internet_connection()
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name, cache_dir=cache_dir, local_files_only=local_files_only, use_auth_token=use_auth_token)
    config = transformers.AutoConfig.from_pretrained(
        model_name, local_files_only=local_files_only, cache_dir=cache_dir, use_auth_token=use_auth_token)

    # model type -> name of the conditional-generation class; resolved by name so that a
    # class missing from the installed transformers version only matters when it is requested
    model_class_names = {
        't5': 'T5ForConditionalGeneration',
        'mt5': 'MT5ForConditionalGeneration',
        'bart': 'BartForConditionalGeneration',
        'mbart': 'MBartForConditionalGeneration',
        'switch_transformers': 'SwitchTransformersForConditionalGeneration',
    }
    if config.model_type not in model_class_names:
        raise ValueError(f'unsupported model type: {config.model_type}')
    model_class = getattr(transformers, model_class_names[config.model_type]).from_pretrained

    param = {'config': config, "local_files_only": local_files_only, "use_auth_token": use_auth_token,
             "low_cpu_mem_usage": low_cpu_mem_usage, "cache_dir": cache_dir}
    if torch_dtype is not None:
        param['torch_dtype'] = torch_dtype
    if device_map is not None:
        param['device_map'] = device_map
    model = model_class(model_name, **param)
    # add new special tokens to the tokenizer and the model if they don't have it
    tokenizer.add_special_tokens({'additional_special_tokens': list(ADDITIONAL_SP_TOKENS.values())})
    model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model, config
99
+
100
+
101
def label_smoothed_loss(logits, labels, epsilon):
    """ Label-smoothed negative log-likelihood that ignores positions labelled `CE_IGNORE_INDEX`.

    Adapted from
    https://github.com/huggingface/transformers/blob/55bb4c06f7be141c6d895dbe1f11018dc8580b2d/src/transformers/trainer_pt_utils.py#L430

    @param logits: Unnormalised scores; the softmax is taken over the last dimension.
    @param labels: Target indices; may lack the trailing dimension of `logits`. Left unmodified.
    @param epsilon: Smoothing weight; the result is (1 - epsilon) * nll + epsilon * smoothed.
    @return: Scalar loss tensor averaged over the non-ignored positions.
    """
    log_probs = - functional.log_softmax(logits, dim=-1)
    if labels.dim() == log_probs.dim() - 1:
        labels = labels.unsqueeze(-1)

    padding_mask = labels.eq(CE_IGNORE_INDEX)
    # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
    # will ignore them in any case. The clamp is out-of-place on purpose: `unsqueeze` returns a view,
    # so an in-place clamp would overwrite the ignore markers in the caller's tensor.
    labels = labels.clamp(min=0)

    nll_loss = log_probs.gather(dim=-1, index=labels)
    nll_loss.masked_fill_(padding_mask, 0.0)

    # works for fp16 input tensor too, by internally upcasting it to fp32
    smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
    smoothed_loss.masked_fill_(padding_mask, 0.0)

    # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
    num_active_elements = padding_mask.numel() - padding_mask.long().sum()
    nll_loss = nll_loss.sum() / num_active_elements
    smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
    return (1 - epsilon) * nll_loss + epsilon * smoothed_loss
124
+
125
+
126
class Dataset(torch.utils.data.Dataset):
    """ Map-style dataset that serves a list of encoded records as dictionaries of tensors. """
    # features listed here are converted to float32; every other feature becomes int64
    float_tensors = ['attention_mask']

    def __init__(self, data: List):
        """ @param data: List of records, each a mapping from feature name to a (nested) list of numbers. """
        self.data = data

    def __len__(self):
        """ Number of records. """
        return len(self.data)

    def to_tensor(self, name, data):
        """ Convert one feature to a tensor, picking the dtype from the feature name.

        @param name: Feature name.
        @param data: Feature value.
        @return: torch.Tensor
        """
        dtype = torch.float32 if name in self.float_tensors else torch.long
        return torch.tensor(data, dtype=dtype)

    def __getitem__(self, idx):
        """ Return record `idx` with every feature converted by `to_tensor`. """
        record = self.data[idx]
        return {key: self.to_tensor(key, value) for key, value in record.items()}
143
+
144
+
145
class EncodePlus:
    """ Wrapper of encode_plus for multiprocessing. """

    def __init__(self,
                 tokenizer,
                 max_length: int = 512,
                 max_length_output: int = 34,
                 drop_overflow_error_text: bool = False,
                 skip_overflow_error: bool = False,
                 drop_highlight_error_text: bool = False,
                 prefix_type: str = None,
                 padding: bool = True):
        """ Wrapper of encode_plus for multiprocessing.

        @param tokenizer: Tokenizer used to encode the input/output text.
        @param max_length: Max text length of input.
        @param max_length_output: Max text length of output.
        @param drop_overflow_error_text: If true, return None (instead of raising) when the input or the output
            exceeds its max length.
        @param skip_overflow_error: If true (and `drop_overflow_error_text` is false), skip the length check
            altogether; the text is still truncated by the tokenizer.
        @param drop_highlight_error_text: If true, return None (instead of raising) when the highlight span is
            not found in the input.
        @param prefix_type: Key of `TASK_PREFIX` selecting the task prefix to add at the beginning of the text,
            or None for no prefix.
        @param padding: Pad the sequence to the max length.
        """
        self.prefix = TASK_PREFIX[prefix_type] if prefix_type is not None else None
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_length_output = max_length_output
        # NOTE: for model training, we should drop the exceeded input but not for the evaluator
        self.drop_overflow_error_text = drop_overflow_error_text
        self.skip_overflow_error = skip_overflow_error
        self.drop_highlight_error_text = drop_highlight_error_text
        # truncation should be true for the batch process, but not necessary to process single input
        self.param_in = {'truncation': True, 'max_length': self.max_length}
        self.param_out = {'truncation': True, 'max_length': self.max_length_output}
        if padding:
            self.param_in['padding'] = 'max_length'
            self.param_out['padding'] = 'max_length'

    def __call__(self, inputs):
        """ Encode one `(input_sequence, output_sequence, input_highlight)` tuple (see `encode_plus`). """
        return self.encode_plus(*inputs)

    def encode_plus(self, input_sequence: str, output_sequence: str = None, input_highlight: str = None):
        """ encode_plus

        @param input_sequence: Input sequence.
        @param output_sequence: Output sequence.
        @param input_highlight: Sub-sequence of `input_sequence` to be surrounded by <hl>.
        @return: The output of the tokenizer (with `labels` added when `output_sequence` is given), or None
            when the example is dropped because of a missing highlight or an overflow.
        @raise HighlightNotFoundError: The highlight is not in the input and `drop_highlight_error_text` is false.
        @raise ExceedMaxLengthError: The input/output is too long and neither `drop_overflow_error_text` nor
            `skip_overflow_error` is set.
        """
        # add highlight to the input (only the first occurrence is highlighted)
        if input_highlight is not None:
            position = input_sequence.find(input_highlight)
            if position == -1:
                if self.drop_highlight_error_text:
                    return None
                raise HighlightNotFoundError(input_highlight, input_sequence)
            input_sequence = '{0}{1} {2} {1}{3}'.format(
                input_sequence[:position], ADDITIONAL_SP_TOKENS['hl'], input_highlight,
                input_sequence[position+len(input_highlight):])
        if self.prefix is not None:
            input_sequence = f'{self.prefix}: {input_sequence}'

        # handling overflow text
        #   drop_overflow_error_text ==> remove the overflow sentence from input
        #   skip_overflow_error ==> keep the overflow sentence
        #   none of them ==> raise error
        if self.drop_overflow_error_text or not self.skip_overflow_error:
            if len(self.tokenizer.encode(input_sequence)) > self.max_length:
                if not self.drop_overflow_error_text:  # raise error for overflow text
                    raise ExceedMaxLengthError(self.max_length)
                return None  # remove overflow text
            if output_sequence is not None:
                if len(self.tokenizer.encode(output_sequence)) > self.max_length_output:
                    if not self.drop_overflow_error_text:  # raise error for overflow text
                        # report the limit that was actually exceeded (the output one, not the input one)
                        raise ExceedMaxLengthError(self.max_length_output)
                    return None  # remove overflow text
        if type(self.tokenizer) is transformers.models.mbart.tokenization_mbart_fast.MBartTokenizerFast:
            encode = self.tokenizer(input_sequence, **self.param_in)
        else:
            # NOTE(review): the input is passed as `text_target` for every non-mBART tokenizer; kept as in
            # the original, confirm this is intended.
            encode = self.tokenizer(text_target=input_sequence, **self.param_in)
        if output_sequence is not None:
            encode['labels'] = self.tokenizer.encode(output_sequence, **self.param_out)
        return encode
228
+
229
+
230
+ class TransformersQG:
231
+ """ Transformers Language Model for Question Generation. """
232
+
233
def __init__(self,
             model: str = None,
             max_length: int = 512,
             max_length_output: int = 256,
             model_ae: str = None,
             max_length_ae: int = 512,
             max_length_output_ae: int = 64,
             cache_dir: str = None,
             add_prefix: bool = None,
             language: str = 'vi',
             label_smoothing: float = None,
             skip_overflow_error: bool = False,
             drop_overflow_error_text: bool = False,
             drop_highlight_error_text: bool = False,
             drop_answer_error_text: bool = False,
             use_auth_token: bool = False,
             torch_dtype=None,
             device_map: str = None,
             low_cpu_mem_usage: bool = False,
             is_qg: bool = None,
             is_qag: bool = None,
             is_qa: bool = None,
             is_ae: bool = None):
    """ Transformers Language Model for Question Generation.

    @param model: Model alias or path to local model file; defaults to `DEFAULT_MODELS[language]`.
    @param max_length: Max text length of input.
    @param max_length_output: Max text length of output.
    @param model_ae: Answer extraction model: a name in `VALID_METHODS` (spaCy keyword extraction), the same
        name as `model` (multitask), or another model to load. If None, `model` is used when it is an AE
        model and "positionrank" otherwise.
    @param max_length_ae: Max text length of input for the answer extraction model.
    @param max_length_output_ae: Max text length of output for the answer extraction model.
    @param cache_dir: Directory to cache transformers model files.
    @param add_prefix: Whether model uses task-specific prefix (eg. True for T5 but False for BART models).
        Only used when the model config has no `add_prefix` entry.
    @param language: Language alias for SpaCy language-specific pipelines (sentencizer/keyword extraction).
    @param label_smoothing: [Fine-tuning parameter] Label smoothing.
    @param skip_overflow_error: If true, do not raise an error when the input exceeds the max length.
    @param drop_overflow_error_text: If true, return None when the input exceeds the max length.
    @param drop_highlight_error_text: If true, return None when a highlight span is not found in the paragraph.
    @param drop_answer_error_text: If true, return None for a context without answer instead of raising.
    @param use_auth_token: [optional] Huggingface transformers argument of `use_auth_token`
    @param torch_dtype: Forwarded to `load_language_model` for the main model.
    @param device_map: Forwarded to `load_language_model` for the main model.
    @param low_cpu_mem_usage: Forwarded to `load_language_model` for the main model.
    @param is_qg: Whether the model does question generation; if None, true when `qg` is one of the
        `-`-separated parts of the model name. `is_qag`, `is_qa` and `is_ae` work the same way.
    """
    # fall back to the default model of the language
    if model is None:
        assert language in DEFAULT_MODELS, \
            f"Model with language '{language}' is not available. Please choose language from " \
            f"'{DEFAULT_MODELS.keys()}' or specify 'model'."
        model = DEFAULT_MODELS[language]

    # task flags: explicit argument wins, otherwise look at the model name
    name_parts = model.split('-')
    self.is_qg = is_qg if is_qg is not None else 'qg' in name_parts
    self.is_ae = is_ae if is_ae is not None else 'ae' in name_parts
    self.is_qa = is_qa if is_qa is not None else 'qa' in name_parts
    self.is_qag = is_qag if is_qag is not None else 'qag' in name_parts

    # configs
    self.model_name = model
    self.max_length = max_length
    self.max_length_output = max_length_output
    self.label_smoothing = label_smoothing
    self.drop_overflow_error_text = drop_overflow_error_text
    self.skip_overflow_error = skip_overflow_error
    self.drop_highlight_error_text = drop_highlight_error_text
    self.drop_answer_error_text = drop_answer_error_text
    self.model_name_ae = model_ae
    self.max_length_ae = max_length_ae
    self.max_length_output_ae = max_length_output_ae

    # main model
    self.tokenizer, self.model, config = load_language_model(
        self.model_name, cache_dir=cache_dir, use_auth_token=use_auth_token, device_map=device_map,
        torch_dtype=torch_dtype, low_cpu_mem_usage=low_cpu_mem_usage)
    if 'add_prefix' in config.to_dict():
        self.add_prefix = config.add_prefix
    else:
        # the config carries no `add_prefix`, i.e. the model is not fine-tuned: trust the argument
        self.add_prefix = add_prefix

    # answer extraction backend
    if self.model_name_ae is None:
        self.model_name_ae = self.model_name if self.is_ae else "positionrank"
    self.answer_model_type = None
    if self.model_name_ae in VALID_METHODS:
        logging.info(f'use spaCy answer extraction model: {self.model_name_ae}')
        self.tokenizer_ae = self.model_ae = self.add_prefix_ae = None
        self.spacy_module = SpacyPipeline(language, self.model_name_ae)
        self.answer_model_type = 'spacy'
    else:
        logging.info(f'use LMQG fine-tuned answer extraction model: {self.model_name_ae}')
        if self.model_name == self.model_name_ae:
            logging.info("the same model as QG is used as AE")
            assert self.is_ae, f"the model ({self.model_name_ae}) is not fine-tuned for AE"
            self.tokenizer_ae = self.model_ae = self.add_prefix_ae = None
            self.answer_model_type = 'multitask'
        else:
            logging.info(f"loading 2nd model for AE: {self.model_name_ae}")
            self.tokenizer_ae, self.model_ae, config_ae = load_language_model(
                model_ae, cache_dir=cache_dir, use_auth_token=use_auth_token)
            self.add_prefix_ae = config_ae.add_prefix
            self.answer_model_type = 'pipeline'
        self.spacy_module = SpacyPipeline(language)

    # GPU setup
    num_gpu = torch.cuda.device_count()
    self.device = 'cuda' if num_gpu > 0 else 'cpu'
    self.parallel = num_gpu > 1
    if self.parallel:
        self.model = torch.nn.DataParallel(self.model)
        if self.model_ae is not None:
            self.model_ae = torch.nn.DataParallel(self.model_ae)
    self.model.to(self.device)
    if self.model_ae is not None:
        self.model_ae.to(self.device)
    logging.info(f'Model `{self.model_name}`')
    logging.info(f'\t * Num of GPU in use: {num_gpu}')
    logging.info(f'\t * Prefix: {self.add_prefix}')
    logging.info(f'\t * Language: {language} (ignore at the training phase)')
345
+
346
def push_to_hub(self, repo_id):
    """ Push the model and then the tokenizer to `repo_id`.

    @param repo_id: Repository identifier handed to the `push_to_hub` of the model and the tokenizer.
    """
    # under DataParallel the transformers model is reachable through `.module`
    target = self.model.module if self.parallel else self.model
    target.push_to_hub(repo_id)
    self.tokenizer.push_to_hub(repo_id)
352
+
353
def generate_qa_end2end(self,
                        list_context: str or List,
                        batch_size: int = None,
                        num_beams: int = 4,
                        cache_path: str = None,
                        splitting_symbol: str = ' [SEP] ',
                        question_prefix: str = "question: ",
                        answer_prefix: str = ", answer: "):
    """ Generate question-answer pairs from each context in a single pass of the end2end QAG model.

    @param list_context: Input text or list of input texts.
    @param batch_size: Batch size.
    @param num_beams: Number of beam for model generation.
    @param cache_path: Path to pre-compute features.
    @param splitting_symbol: Separator between the pairs in the generated text.
    @param question_prefix: Marker in front of the question within each generated pair.
    @param answer_prefix: Marker between the question and the answer within each generated pair.
    @return: List of `(question, answer)` tuples for a single text, otherwise one such list per context.
        Generated chunks that do not follow the expected format are logged and left out.
    """
    logging.info(f'running model for `question_answer_pair_generation`')
    assert self.is_qag, "`generate_qa_end2end` is available for end2end_qag_model"
    single_input = type(list_context) is str
    contexts = [list_context] if single_input else list_context
    predictions = self.generate_prediction(
        contexts,
        prefix_type='qag' if self.add_prefix else None,
        cache_path=cache_path,
        num_beams=num_beams,
        batch_size=batch_size
    )

    def parse_pairs(chunks):
        # a well-formed chunk has exactly one answer marker and contains the question marker
        pairs = []
        for chunk in chunks:
            parts = chunk.split(answer_prefix)
            if len(parts) == 2 and question_prefix in chunk:
                question, answer = parts
                pairs.append((question.replace(question_prefix, "").strip(), answer.strip()))
            else:
                logging.info(f"invalid prediction: {chunk}")
        return pairs

    parsed = [parse_pairs(prediction.split(splitting_symbol)) for prediction in predictions]
    if single_input:
        return parsed[0]
    return parsed
396
+
397
def generate_qa(self,
                list_context: str or List,
                batch_size: int = None,
                num_beams: int = 4,
                cache_path: str = None,
                num_questions: int = None,
                sentence_level: bool = False):
    """ Generate question-answer pairs given context (answer extraction followed by question generation).

    @param list_context: Input text or list of input texts.
    @param batch_size: Batch size.
    @param num_beams: Number of beam for model generation.
    @param cache_path: Path to pre-compute features of the question generation step. The features of the
        answer extraction step are cached separately at `<cache_path>.ae`.
    @param num_questions: Max number of questions.
    @param sentence_level: Run prediction on each sentence of the context independently to reduce complexity.
    @return: List of `(question, answer)` tuples for a single text, otherwise one entry per context; the
        entry is None for a context without any extracted answer.
    """
    if self.is_qag:
        return self.generate_qa_end2end(list_context, batch_size, num_beams, cache_path)
    single_input = isinstance(list_context, str)
    list_context = [list_context] if single_input else list_context
    original_input_length = len(list_context)

    logging.info('running model for `ae`')
    # The feature cache is looked up by path only, so handing the same file to both steps would make the
    # question generation step reload the answer extraction features; keep the two caches apart.
    cache_path_ae = None if cache_path is None else f'{cache_path}.ae'
    list_answer = self.generate_a(
        list_context,
        batch_size=batch_size,
        num_beams=num_beams,
        cache_path=cache_path_ae,
        sentence_level=sentence_level,
        num_questions=num_questions
    )
    # keep the contexts with at least one answer, flattened to one (context, answer) pair per question
    valid_context_id = [n for n, a in enumerate(list_answer) if a is not None]
    list_answer = [list_answer[n] for n in valid_context_id]
    qg_input, qg_hl = [], []
    for n, answers in zip(valid_context_id, list_answer):
        qg_hl += answers
        qg_input += [list_context[n]] * len(answers)

    logging.info('running model for `qg`')
    if qg_input:
        list_question = self.generate_q(
            qg_input,
            list_answer=qg_hl,
            batch_size=batch_size,
            cache_path=cache_path,
            num_beams=num_beams,
            sentence_level=sentence_level
        )
    else:
        # nothing to ask about: skip the model call, which cannot handle an empty batch
        list_question = []

    assert len(qg_hl) == len(list_question), f"{len(qg_input)} != {len(list_question)}"

    # return to nested list, aligned with the original input
    output_list = [None] * original_input_length
    offset = 0
    for n, answers in zip(valid_context_id, list_answer):
        output_list[n] = list(zip(list_question[offset:offset + len(answers)], answers))
        offset += len(answers)
    return output_list[0] if single_input else output_list
459
+
460
def generate_a(self,
               context: str or List,
               batch_size: int = None,
               num_beams: int = 4,
               cache_path: str = None,
               sentence_level: bool = False,
               num_questions: int = None):
    """ Extract answer candidates from the context, one model prediction per sentence.

    @param context: Input text or list of input texts.
    @param batch_size: Batch size.
    @param num_beams: Number of beam for model generation.
    @param cache_path: Path to pre-compute features.
    @param sentence_level: Feed each sentence on its own instead of the whole context with the sentence
        highlighted.
    @param num_questions: Max number of answers for the spaCy keyword extraction (10 if None); not used by
        the language-model based extraction.
    @return: List of answers for a single text, otherwise one list per context. With the language-model
        based extraction, a context without any valid answer yields None when `drop_answer_error_text` is set.
    @raise AnswerNotFoundError: No valid answer for some context and `drop_answer_error_text` is not set.
    """
    logging.info(f'running model for `answer_extraction`')
    if self.answer_model_type == 'spacy':
        if num_questions is None:
            num_questions = 10
        if type(context) is str:
            return self.spacy_module.keyword(context, num_questions)
        return [self.spacy_module.keyword(c, num_questions) for c in context]

    single_input = type(context) is str
    if single_input:
        context = [context]
    list_sentences = [self.spacy_module.sentence(c) for c in context]  # split into sentence
    if sentence_level:
        list_inputs = list_sentences
    else:
        # one copy of the whole context per sentence to highlight
        list_inputs = [[c] * len(s) for c, s in zip(context, list_sentences)]
    list_length = [0] + np.cumsum([len(s) for s in list_sentences]).tolist()

    # flatten inputs
    flat_sentences = list(chain(*list_sentences))
    flat_inputs = list(chain(*list_inputs))
    shared = dict(highlights=flat_sentences, cache_path=cache_path, num_beams=num_beams, batch_size=batch_size)
    if self.answer_model_type == 'multitask':
        answer = self.generate_prediction(flat_inputs, prefix_type='ae' if self.add_prefix else None, **shared)
    elif self.answer_model_type == 'pipeline':
        answer = self.generate_prediction(
            flat_inputs, prefix_type='ae' if self.add_prefix_ae else None, switch_to_model_ae=True, **shared)
    else:
        raise ValueError(f"unknown answer model type: {self.answer_model_type}")

    # return to nested list, keeping only the answers that literally occur in their model input
    answer = [clean(a) for a in answer]
    list_answer = []
    for start, end, model_inputs in zip(list_length[:-1], list_length[1:], list_inputs):
        kept = [a for a, c in zip(answer[start:end], model_inputs) if a is not None and a in c]
        list_answer.append(kept if len(kept) != 0 else None)
    if not self.drop_answer_error_text:
        for c, a in zip(context, list_answer):
            if a is None:
                raise AnswerNotFoundError(c)
    return list_answer[0] if single_input else list_answer
525
+
526
def generate_q(self,
               list_context: str or List,
               list_answer: List = None,
               batch_size: int = None,
               num_beams: int = 4,
               cache_path: str = None,
               sentence_level: bool = False):
    """ Generate question from paragraph and answer. Note that `list_answer` is needed unless they are already
    highlighted in the `list_context`. eg) "I live in <hl> Tokyo <hl>."

    @param list_context: Input text or list of input texts.
    @param list_answer: Answer(s) in `list_context` to highlight by <hl>; must be of the same type as
        `list_context`.
    @param batch_size: Batch size.
    @param num_beams: Number of beam for model generation.
    @param cache_path: Path to pre-compute features.
    @param sentence_level: Run prediction on each sentence of the context independently to reduce complexity.
    @return: Generated question for a single text, otherwise the list of generated questions.
    """
    assert self.is_qg, "model is not fine-tuned for QG"
    if list_answer is not None:
        assert type(list_context) is type(list_answer), f"{type(list_context)} != {type(list_answer)}"
    single_input = type(list_context) is str
    if single_input:
        # wrap a bare string (and its answer) so that the model always sees a batch
        list_context = [list_context]
        if list_answer is not None:
            list_answer = [list_answer]
    output = self.generate_prediction(
        list_context,
        highlights=list_answer,
        prefix_type='qg' if self.add_prefix else None,
        cache_path=cache_path,
        num_beams=num_beams,
        batch_size=batch_size,
        sentence_level=sentence_level
    )
    return output[0] if single_input else output
564
+
565
def answer_q(self,
             list_context: str or List,
             list_question: str or List,
             batch_size: int = None,
             num_beams: int = 4,
             cache_path: str = None):
    """ Answer question(s) given the context(s).

    @param list_context: Input text or list of input texts.
    @param list_question: Question or list of questions, of the same type and length as `list_context`.
    @param batch_size: Batch size.
    @param num_beams: Number of beam for model generation.
    @param cache_path: Path to pre-compute features.
    @return: Generated answer for a single text, otherwise the list of generated answers.
    """
    logging.info(f'running model for `question_answering`')
    assert self.is_qa, "model is not fine-tuned for QA"
    assert type(list_context) is type(list_question), "invalid input"
    single_input = type(list_context) is str
    if single_input:
        list_context, list_question = [list_context], [list_question]
    assert len(list_context) == len(list_question), f"invalid input: {len(list_context)} != {len(list_question)}"
    model_inputs = [f"question: {q}, context: {c}" for q, c in zip(list_question, list_context)]
    output = self.generate_prediction(
        model_inputs,
        batch_size=batch_size,
        prefix_type='qa' if self.add_prefix else None,
        cache_path=cache_path,
        num_beams=num_beams
    )
    if single_input:
        return output[0]
    return output
586
+
587
def generate_prediction(self,
                        inputs: List,
                        highlights: List or None = None,
                        prefix_type: str = None,
                        num_beams: int = 4,
                        batch_size: int = None,
                        cache_path: str = None,
                        sentence_level: bool = False,
                        switch_to_model_ae: bool = False):
    """ General method to generate model prediction

    @param inputs: List of input sequences.
    @param highlights: List of sub-sequences from list_context to be highlighted by <hl>.
    @param prefix_type: Key of `TASK_PREFIX` selecting the task prefix added to the text, or None.
    @param num_beams: Number of beam for model generation.
    @param batch_size: Batch size.
    @param cache_path: Path to pre-compute features.
    @param sentence_level: Replace each input by its first sentence containing the highlight (the input is
        kept when no sentence contains it); requires `highlights`.
    @param switch_to_model_ae: Use the answer extraction model/tokenizer instead of the main ones.
    @return: List of generated sequences (empty when there is nothing to encode).
    """
    self.eval()
    if switch_to_model_ae:
        assert self.model_ae is not None and self.tokenizer_ae is not None
        model = self.model_ae
        tokenizer = self.tokenizer_ae
        max_length_output = self.max_length_output_ae
        # `self.eval()` only covers the main model; make sure the AE model is in inference mode as well
        model.eval()
    else:
        model = self.model
        tokenizer = self.tokenizer
        max_length_output = self.max_length_output

    if sentence_level:
        assert highlights is not None, '`sentence_level` needs `highlights`.'
        assert len(highlights) == len(inputs), str([len(highlights), len(inputs)])
        sentence_module = self.spacy_module
        inputs = [
            next((sentence for sentence in sentence_module.sentence(context) if answer in sentence), context)
            for context, answer in zip(inputs, highlights)
        ]

    assert type(inputs) is list, inputs
    encode_list = self.text_to_encode(
        inputs,
        highlights=highlights,
        prefix_type=prefix_type,
        cache_path=cache_path,
        switch_to_model_ae=switch_to_model_ae
    )
    if not encode_list:
        # no feature left (empty input or everything dropped): a loader cannot be built for an empty batch
        return []
    loader = self.get_data_loader(encode_list, batch_size=batch_size)
    generator = model.module if self.parallel else model  # DataParallel does not expose `generate`
    outputs = []
    with torch.no_grad():
        for batch in loader:
            # labels are not an argument of `generate`
            encode = {k: v.to(self.device) for k, v in batch.items() if k != 'labels'}
            encode['max_length'] = max_length_output
            encode['num_beams'] = num_beams
            tensor = generator.generate(**encode)
            outputs += tokenizer.batch_decode(tensor, skip_special_tokens=True)
    return outputs
646
+
647
def encode_to_loss(self, encode: Dict):
    """ Run the model on a batch of encoded features and return the training loss.

    @param encode: Encoded feature (must contain `labels`).
    @return: The loss of the model output (averaged when running in parallel), or the label-smoothed loss
        computed from the logits when `label_smoothing` is set to a non-zero value.
    """
    assert 'labels' in encode
    batch = {name: tensor.to(self.device) for name, tensor in encode.items()}
    output = self.model(**batch)
    use_smoothing = self.label_smoothing is not None and self.label_smoothing != 0.0
    if use_smoothing:
        return label_smoothed_loss(output['logits'], encode['labels'].to(self.device), self.label_smoothing)
    if self.parallel:
        return output['loss'].mean()
    return output['loss']
659
+
660
def text_to_encode(self,
                   inputs,
                   outputs: List = None,
                   highlights: List = None,
                   prefix_type: str = None,
                   cache_path: str = None,
                   switch_to_model_ae: bool = False):
    """ Transform texts into encoded features.

    @param inputs: List of input sequences.
    @param outputs: List of output sequences.
    @param highlights: List of sub-sequences from `inputs` to be highlighted by <hl>.
    @param prefix_type: Key of `TASK_PREFIX` selecting the task prefix added to the text, or None.
    @param cache_path: Path to pre-compute features. If the file exists it is returned as is, without
        looking at the other arguments; otherwise the encoded features are written to it.
    @param switch_to_model_ae: Encode with the tokenizer and the length limits of the answer extraction model.
    @return: List of encoded feature (examples that the encoder drops are left out).
    """
    if cache_path is not None and os.path.exists(cache_path):
        logging.info(f'loading preprocessed feature from {cache_path}')
        # NOTE: the cache is a pickle file, and unpickling can execute arbitrary code; only point
        # `cache_path` at files from a trusted source.
        return pickle_load(cache_path)
    outputs = [None] * len(inputs) if outputs is None else outputs
    highlights = [None] * len(inputs) if highlights is None else highlights
    assert len(outputs) == len(inputs) == len(highlights), str([len(outputs), len(inputs), len(highlights)])
    data = list(zip(inputs, outputs, highlights))
    # process in parallel/single
    config = {'tokenizer': self.tokenizer, 'max_length': self.max_length, 'prefix_type': prefix_type,
              'max_length_output': self.max_length_output, 'drop_overflow_error_text': self.drop_overflow_error_text,
              'skip_overflow_error': self.skip_overflow_error, 'drop_highlight_error_text': self.drop_highlight_error_text,
              'padding': len(data) != 1}  # a single example needs no padding
    if switch_to_model_ae:
        assert self.model_ae is not None and self.tokenizer_ae is not None
        config['tokenizer'] = self.tokenizer_ae
        config['max_length'] = self.max_length_ae
        config['max_length_output'] = self.max_length_output_ae

    logging.info(f'encode all the data       : {len(data)}')
    if cache_path is not None:
        # a bare file name has no directory part, and `os.makedirs('')` raises
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
    encoder = EncodePlus(**config)
    files = []  # temporary chunk files, only used by the sequential path
    if PARALLEL_PROCESSING:
        # the context manager releases the workers even when encoding raises
        with Pool() as pool:
            out = pool.map(encoder, data)
        out = list(filter(None, out))  # remove overflow text
    else:
        out = []
        for i in tqdm(data):
            e = encoder(i)
            if e is not None:  # remove overflow text
                out.append(e)
            # spill to disk in chunks to bound the memory while encoding
            if len(out) > 40000 and cache_path is not None:
                pickle_save(out, f'{cache_path}.tmp{len(files)}')
                files.append(f'{cache_path}.tmp{len(files)}')
                out = []
        if len(out) > 0 and cache_path is not None:
            pickle_save(out, f'{cache_path}.tmp{len(files)}')
            files.append(f'{cache_path}.tmp{len(files)}')
        if len(files) > 0:
            out = list(chain(*[pickle_load(i) for i in files]))
    logging.info(f'after remove the overflow : {len(out)}')
    # cache the encoded data
    if cache_path is not None:
        pickle_save(out, cache_path)
        logging.info(f'preprocessed feature is saved at {cache_path}')
        # the chunks are merged into `cache_path` now; do not leave them behind
        for tmp_file in files:
            try:
                os.remove(tmp_file)
            except OSError as error:
                logging.warning(f'could not remove temporary file {tmp_file}: {error}')
    return out
725
+
726
def save(self, save_dir):
    """ Save the model (with `add_prefix` recorded in its config) and the tokenizer.

    @param save_dir: Directory to save model related file.
    """
    # under DataParallel the transformers model is reachable through `.module`
    underlying = self.model.module if self.parallel else self.model

    logging.info('saving model')
    underlying.config.update({'add_prefix': self.add_prefix})
    underlying.save_pretrained(save_dir)
    logging.info('saving tokenizer')
    self.tokenizer.save_pretrained(save_dir)
742
+
743
@staticmethod
def get_data_loader(encode_list, batch_size: int = None, shuffle: bool = False, drop_last: bool = False):
    """ Get torch.utils.data.DataLoader instance.

    @param encode_list: List of encoded features.
    @param batch_size: Batch size; if None, the whole list forms a single batch.
    @param shuffle: Shuffle data.
    @param drop_last: Drop residual batch.
    @return: torch.utils.data.DataLoader
    """
    if batch_size is None:
        # DataLoader rejects a batch size of 0, so an empty list still gets a valid (never used) size
        batch_size = max(len(encode_list), 1)
    params = dict(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=NUM_WORKERS)
    return torch.utils.data.DataLoader(Dataset(encode_list), **params)
756
+
757
def train(self):
    """ Switch the main model to training mode. """
    model = self.model
    model.train()
759
+
760
def eval(self):
    """ Switch the main model to evaluation mode. """
    model = self.model
    model.eval()