Casually committed on
Commit
eba33b5
·
1 Parent(s): b59ac93

Upload modeling_uie.py

  1. modeling_uie.py +710 -0
modeling_uie.py ADDED
@@ -0,0 +1,710 @@
+ # -*- coding: utf-8 -*-
+ import math
+ import re
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, List, Union, Dict
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from transformers import ErnieModel, ErniePreTrainedModel, PretrainedConfig, PreTrainedTokenizerFast
+ from transformers.utils import ModelOutput
+
+
+ @dataclass
+ class UIEModelOutput(ModelOutput):
+     """
+     Output class for outputs of UIE.
+     Args:
+         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions` and `end_positions` are provided):
+             Total span extraction loss: the mean of the binary cross-entropy losses for the start and end positions.
+         start_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+             Span-start scores (after Sigmoid).
+         end_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+             Span-end scores (after Sigmoid).
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding
+             layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+             Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+             sequence_length)`.
+             Attention weights after the attention softmax, used to compute the weighted average in the
+             self-attention heads.
+     """
+     loss: Optional[torch.FloatTensor] = None
+     start_prob: torch.FloatTensor = None
+     end_prob: torch.FloatTensor = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ class UIE(ErniePreTrainedModel):
+     """
+     UIE model based on the ERNIE encoder.
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+     etc.)
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+     and behavior.
+     Parameters:
+         config ([`PretrainedConfig`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+     """
+
+     def __init__(self, config: PretrainedConfig):
+         super(UIE, self).__init__(config)
+         self.encoder = ErnieModel(config)
+         self.config = config
+         hidden_size = self.config.hidden_size
+
+         # One scalar score per token for span start and span end.
+         self.linear_start = nn.Linear(hidden_size, 1)
+         self.linear_end = nn.Linear(hidden_size, 1)
+         self.sigmoid = nn.Sigmoid()
+
+         self.post_init()
+
+     def forward(self, input_ids: Optional[torch.Tensor] = None,
+                 token_type_ids: Optional[torch.Tensor] = None,
+                 position_ids: Optional[torch.Tensor] = None,
+                 attention_mask: Optional[torch.Tensor] = None,
+                 head_mask: Optional[torch.Tensor] = None,
+                 inputs_embeds: Optional[torch.Tensor] = None,
+                 start_positions: Optional[torch.Tensor] = None,
+                 end_positions: Optional[torch.Tensor] = None,
+                 output_attentions: Optional[bool] = None,
+                 output_hidden_states: Optional[bool] = None,
+                 return_dict: Optional[bool] = None
+                 ):
+         """
+         Args:
+             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                 Indices of input sequence tokens in the vocabulary.
+                 Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                 [`PreTrainedTokenizer.__call__`] for details.
+                 [What are input IDs?](../glossary#input-ids)
+             attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+                 [What are attention masks?](../glossary#attention-mask)
+             token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+                 1]`:
+                 - 0 corresponds to a *sentence A* token,
+                 - 1 corresponds to a *sentence B* token.
+                 [What are token type IDs?](../glossary#token-type-ids)
+             position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+                 config.max_position_embeddings - 1]`.
+                 [What are position IDs?](../glossary#position-ids)
+             head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+                 Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+                 - 1 indicates the head is **not masked**,
+                 - 0 indicates the head is **masked**.
+             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+                 is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+                 model's internal embedding lookup matrix.
+             start_positions (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Binary labels (1.0 at span-start tokens, 0.0 elsewhere) used to compute the binary cross-entropy loss
+                 against `start_prob`. `nn.BCELoss` requires the target to have the same shape and floating dtype as
+                 the probabilities.
+             end_positions (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Binary labels (1.0 at span-end tokens, 0.0 elsewhere) used to compute the binary cross-entropy loss
+                 against `end_prob`. `nn.BCELoss` requires the target to have the same shape and floating dtype as
+                 the probabilities.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+                 tensors for more detail.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+                 more detail.
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         outputs = self.encoder(
+             input_ids=input_ids,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict
+         )
+         sequence_output = outputs[0]
+
+         # Project each token to a scalar and squash to (0, 1): shape (batch_size, sequence_length).
+         start_logits = self.linear_start(sequence_output)
+         start_logits = torch.squeeze(start_logits, -1)
+         start_prob = self.sigmoid(start_logits)
+         end_logits = self.linear_end(sequence_output)
+         end_logits = torch.squeeze(end_logits, -1)
+         end_prob = self.sigmoid(end_logits)
+
+         total_loss = None
+         if start_positions is not None and end_positions is not None:
+             loss_fct = nn.BCELoss()
+             start_loss = loss_fct(start_prob, start_positions)
+             end_loss = loss_fct(end_prob, end_positions)
+             total_loss = (start_loss + end_loss) / 2.0
+
+         if not return_dict:
+             output = (start_prob, end_prob) + outputs[2:]
+             return ((total_loss,) + output) if total_loss is not None else output
+
+         return UIEModelOutput(
+             loss=total_loss,
+             start_prob=start_prob,
+             end_prob=end_prob,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def predict(self, schema: Union[Dict, List[str], str], input_texts: Union[List[str], str],
+                 tokenizer: PreTrainedTokenizerFast, max_length: int = 512, batch_size: int = 32,
+                 position_prob: float = 0.5, progress_hook=None) -> List[Dict]:
+         """
+         Run schema-driven extraction on one or more texts.
+
+         Args:
+             schema (Union[Dict, List[str], str]): Extraction targets.
+             input_texts (Union[List[str], str]): Text(s) to extract from.
+             tokenizer (PreTrainedTokenizerFast): Fast tokenizer; offset mappings are required.
+             max_length (int): Maximum tokenized sequence length (prompt plus text).
+             batch_size (int): Number of prompt/text pairs per forward pass.
+             position_prob (float): Probability threshold for start and end positions.
+             progress_hook: Optional progress object exposing `reset(total=...)`, `update(n)`, `n` and `close()`,
+                 e.g. a tqdm bar.
+
+         Returns:
+             result (List[Dict]): One dict of extraction results per input text.
+         """
+
+         predictor = UIEPredictor(self, tokenizer=tokenizer, schema=schema, max_length=max_length,
+                                  position_prob=position_prob, batch_size=batch_size, hook=progress_hook)
+         input_texts = [input_texts] if isinstance(input_texts, str) else input_texts
+         return predictor.predict(input_texts)
+
+
+ class UIEPredictor(object):
+     def __init__(self, model, tokenizer, schema, max_length=512, position_prob=0.5, batch_size=32, hook=None):
+         self.model = model
+         self._tokenizer = tokenizer
+
+         self._position_prob = position_prob
+         self.max_length = max_length
+         self._batch_size = batch_size
+         self._multilingual = getattr(self.model.config, 'multilingual', False)
+         self._schema_tree = self.set_schema(schema)
+         self._hook = hook
+
+     def set_schema(self, schema):
+         if isinstance(schema, dict) or isinstance(schema, str):
+             schema = [schema]
+         return self._build_tree(schema)
+
+     @classmethod
+     def _build_tree(cls, schema, name="root"):
+         """
+         Build the schema tree.
+         """
+         schema_tree = SchemaTree(name)
+         for s in schema:
+             if isinstance(s, str):
+                 schema_tree.add_child(SchemaTree(s))
+             elif isinstance(s, dict):
+                 for k, v in s.items():
+                     if isinstance(v, str):
+                         child = [v]
+                     elif isinstance(v, list):
+                         child = v
+                     else:
+                         raise TypeError(
+                             "Invalid schema, value for each key:value pair should be list or string, "
+                             "but {} received".format(type(v))
+                         )
+                     schema_tree.add_child(cls._build_tree(child, name=k))
+             else:
+                 raise TypeError("Invalid schema, element should be string or dict, " "but {} received".format(type(s)))
+         return schema_tree
+
+     def _single_stage_predict(self, inputs):
+         input_texts = []
+         prompts = []
+         for i in range(len(inputs)):
+             input_texts.append(inputs[i]["text"])
+             prompts.append(inputs[i]["prompt"])
+         # Reserve room for the longest prompt and the three special tokens ([CLS], [SEP], [SEP]).
+         # `max(prompts)` would pick the lexicographically largest prompt, not the longest one.
+         max_predict_len = self.max_length - max(len(p) for p in prompts) - 3
+         short_input_texts, self.input_mapping = Utils.auto_splitter(input_texts, max_predict_len, split_sentence=False)
+
+         short_texts_prompts = []
+         for k, v in self.input_mapping.items():
+             short_texts_prompts.extend([prompts[k] for _ in range(len(v))])
+         short_inputs = [
+             {"text": short_input_texts[i], "prompt": short_texts_prompts[i]} for i in range(len(short_input_texts))
+         ]
+
+         prompts = []
+         texts = []
+         for s in short_inputs:
+             prompts.append(s["prompt"])
+             texts.append(s["text"])
+
+         if self._multilingual:
+             padding_type = "max_length"
+         else:
+             padding_type = "longest"
+
+         encoded_inputs = self._tokenizer(
+             text=prompts,
+             text_pair=texts,
+             stride=2,
+             truncation=True,
+             max_length=self.max_length,
+             padding=padding_type,
+             add_special_tokens=True,
+             return_offsets_mapping=True,
+             return_tensors="np")
+
+         offset_maps = encoded_inputs["offset_mapping"]
+         start_probs = []
+         end_probs = []
+         for idx in range(0, len(texts), self._batch_size):
+             l, r = idx, idx + self._batch_size
+
+             input_ids = encoded_inputs["input_ids"][l:r]
+             token_type_ids = encoded_inputs["token_type_ids"][l:r]
+             attention_mask = encoded_inputs["attention_mask"][l:r]
+
+             if self._multilingual:
+                 input_ids = np.array(
+                     input_ids, dtype="int64")
+                 attention_mask = np.array(
+                     attention_mask, dtype="int64")
+                 # Positions 0..n-1 for real tokens, 0 for padding.
+                 position_ids = (np.cumsum(np.ones_like(input_ids), axis=1)
+                                 - np.ones_like(input_ids)) * attention_mask
+                 input_dict = {
+                     "input_ids": input_ids,
+                     "attention_mask": attention_mask,
+                     "position_ids": position_ids
+                 }
+             else:
+                 input_dict = {
+                     "input_ids": np.array(
+                         input_ids, dtype="int64"),
+                     "token_type_ids": np.array(
+                         token_type_ids, dtype="int64"),
+                     "attention_mask": np.array(
+                         attention_mask, dtype="int64")
+                 }
+
+             start_prob, end_prob = self._infer(input_dict)
+             start_prob = start_prob.tolist()
+             end_prob = end_prob.tolist()
+             start_probs.extend(start_prob)
+             end_probs.extend(end_prob)
+             if self._hook is not None:
+                 self._hook.update(1)
+         start_ids_list = Utils.get_bool_ids_greater_than(start_probs, limit=self._position_prob, return_prob=True)
+         end_ids_list = Utils.get_bool_ids_greater_than(end_probs, limit=self._position_prob, return_prob=True)
+         sentence_ids = []
+         probs = []
+         for start_ids, end_ids, offset_map in zip(start_ids_list, end_ids_list, offset_maps.tolist()):
+             span_list = Utils.get_span(start_ids, end_ids, with_prob=True)
+             sentence_id, prob = Utils.get_id_and_prob(span_list, offset_map)
+             sentence_ids.append(sentence_id)
+             probs.append(prob)
+         results = Utils.convert_ids_to_results(short_inputs, sentence_ids, probs)
+         results = Utils.auto_joiner(results, short_input_texts, self.input_mapping)
+         return results
+
+     def _multi_stage_predict(self, data):
+         """
+         Traverse the schema tree and do multi-stage prediction.
+         Args:
+             data (list): a list of strings
+         Returns:
+             list: a list of predictions, where the list's length
+                 equals the length of `data`
+         """
+         results = [{} for _ in range(len(data))]
+         # Return early on empty input or a missing schema.
+         if len(data) < 1 or self._schema_tree is None:
+             return results
+
+         # Number of batches per schema node, used to size the progress hook.
+         _pre_node_total = len(data) // self._batch_size + (1 if len(data) % self._batch_size else 0)
+         _finish_node = 0
+         if self._hook is not None:
+             self._hook.reset(total=self._schema_tree.shape * _pre_node_total)
+
+         # Copy the child list so that `self._schema_tree.children` stays unchanged.
+         schema_list = self._schema_tree.children[:]
+         while len(schema_list) > 0:
+             node = schema_list.pop(0)
+             examples = []
+             input_map = {}
+             cnt = 0
+             idx = 0
+             if not node.prefix:
+                 for one_data in data:
+                     examples.append({"text": one_data, "prompt": Utils.dbc2sbc(node.name)})
+                     input_map[cnt] = [idx]
+                     idx += 1
+                     cnt += 1
+             else:
+                 for pre, one_data in zip(node.prefix, data):
+                     if len(pre) == 0:
+                         input_map[cnt] = []
+                     else:
+                         for p in pre:
+                             examples.append({"text": one_data, "prompt": Utils.dbc2sbc(p + node.name)})
+                         input_map[cnt] = [i + idx for i in range(len(pre))]
+                         idx += len(pre)
+                     cnt += 1
+             if len(examples) == 0:
+                 result_list = []
+             else:
+                 result_list = self._single_stage_predict(examples)
+
+             if not node.parent_relations:
+                 relations = [[] for _ in range(len(data))]
+                 for k, v in input_map.items():
+                     for idx in v:
+                         if len(result_list[idx]) == 0:
+                             continue
+                         if node.name not in results[k].keys():
+                             results[k][node.name] = result_list[idx]
+                         else:
+                             results[k][node.name].extend(result_list[idx])
+                     if node.name in results[k].keys():
+                         relations[k].extend(results[k][node.name])
+             else:
+                 relations = node.parent_relations
+                 for k, v in input_map.items():
+                     for i in range(len(v)):
+                         if len(result_list[v[i]]) == 0:
+                             continue
+                         if "relations" not in relations[k][i].keys():
+                             relations[k][i]["relations"] = {node.name: result_list[v[i]]}
+                         elif node.name not in relations[k][i]["relations"].keys():
+                             relations[k][i]["relations"][node.name] = result_list[v[i]]
+                         else:
+                             relations[k][i]["relations"][node.name].extend(result_list[v[i]])
+                 new_relations = [[] for _ in range(len(data))]
+                 for i in range(len(relations)):
+                     for j in range(len(relations[i])):
+                         if "relations" in relations[i][j].keys() and node.name in relations[i][j]["relations"].keys():
+                             for k in range(len(relations[i][j]["relations"][node.name])):
+                                 new_relations[i].append(relations[i][j]["relations"][node.name][k])
+                 relations = new_relations
+
+             # Prompts for child nodes are "<extracted text>的<child name>";
+             # "的" is the Chinese possessive particle.
+             prefix = [[] for _ in range(len(data))]
+             for k, v in input_map.items():
+                 for idx in v:
+                     for i in range(len(result_list[idx])):
+                         prefix[k].append(result_list[idx][i]["text"] + "的")
+             for child in node.children:
+                 child.prefix = prefix
+                 child.parent_relations = relations
+                 schema_list.append(child)
+             _finish_node += 1
+             if self._hook is not None:
+                 self._hook.n = _finish_node * _pre_node_total
+         if self._hook is not None:
+             self._hook.close()
+         return results
+
+     def _infer(self, input_dict):
+         for input_name, input_value in input_dict.items():
+             input_dict[input_name] = torch.LongTensor(input_value).to(self.model.device)
+         outputs = self.model(**input_dict)
+         return outputs.start_prob.detach().cpu().numpy(), outputs.end_prob.detach().cpu().numpy()
+
+     def predict(self, input_data):
+         results = self._multi_stage_predict(data=input_data)
+         return results
+
+
+ class SchemaTree(object):
+     """
+     Implementation of SchemaTree
+     """
+
+     def __init__(self, name="root", children=None):
+         self.name = name
+         self.children = []
+         self.prefix = None
+         self.parent_relations = None
+         # Initialise the counter before adding children: `add_child` increments it.
+         self._total_nodes = 0
+         if children is not None:
+             for child in children:
+                 self.add_child(child)
+
+     @property
+     def shape(self):
+         # Total number of descendant nodes (children, grandchildren, ...).
+         return len(self.children) + sum([child.shape for child in self.children])
+
+     def __repr__(self):
+         return self.name
+
+     def add_child(self, node):
+         assert isinstance(node, SchemaTree), "The children of a node should be an instance of SchemaTree."
+         self._total_nodes += 1
+         self.children.append(node)
+
+
+ class Utils:
+
+     @classmethod
+     def dbc2sbc(cls, s):
+         """
+         Convert full-width characters in `s` to their half-width equivalents.
+         """
+         rs = ""
+         for char in s:
+             code = ord(char)
+             if code == 0x3000:
+                 code = 0x0020
+             else:
+                 code -= 0xFEE0
+             # Note: 0x0020 falls outside this range, so as written the
+             # ideographic space (U+3000) is kept unchanged.
+             if not (0x0021 <= code <= 0x7E):
+                 rs += char
+                 continue
+             rs += chr(code)
+         return rs
+
+     @classmethod
+     def cut_chinese_sent(cls, para):
+         """
+         Cut the Chinese sentences more precisely, reference to
+         "https://blog.csdn.net/blmoistawinde/article/details/82379256".
+         """
+         para = re.sub(r'([。！？?])([^”’])', r"\1\n\2", para)  # single-character sentence terminators
+         para = re.sub(r'(\.{6})([^”’])', r"\1\n\2", para)  # English ellipsis (six dots)
+         para = re.sub(r'(…{2})([^”’])', r"\1\n\2", para)  # Chinese ellipsis (two "…" characters)
+         para = re.sub(r'([。！？?][”’])([^，。！？?])', r'\1\n\2', para)
+         para = para.rstrip()
+         return para.split("\n")
+
+     @classmethod
+     def get_bool_ids_greater_than(cls, probs, limit=0.5, return_prob=False):
+         """
+         Get the indices along the last dimension whose probability is greater than a threshold.
+
+         Args:
+             probs (List[List[float]]): The input probability arrays.
+             limit (float): The probability threshold.
+             return_prob (bool): Whether to return the probability alongside each index.
+         Returns:
+             List[List[int]]: The indices along the last dimension that meet the condition
+                 (or `(index, probability)` tuples when `return_prob` is True).
+         """
+         probs = np.array(probs)
+         dim_len = len(probs.shape)
+         if dim_len > 1:
+             result = []
+             for p in probs:
+                 result.append(cls.get_bool_ids_greater_than(p, limit, return_prob))
+             return result
+         else:
+             result = []
+             for i, p in enumerate(probs):
+                 if p > limit:
+                     if return_prob:
+                         result.append((i, p))
+                     else:
+                         result.append(i)
+             return result
+
+     @classmethod
+     def get_span(cls, start_ids, end_ids, with_prob=False):
+         """
+         Get the span set from the start and end position lists.
+
+         Args:
+             start_ids (List[int]/List[tuple]): The start index list.
+             end_ids (List[int]/List[tuple]): The end index list.
+             with_prob (bool): If True, each element of start_ids and end_ids is a tuple like (index, probability).
+         Returns:
+             set: The set of non-overlapping spans; every id can only be used once.
+         """
+         if with_prob:
+             start_ids = sorted(start_ids, key=lambda x: x[0])
+             end_ids = sorted(end_ids, key=lambda x: x[0])
+         else:
+             start_ids = sorted(start_ids)
+             end_ids = sorted(end_ids)
+
+         start_pointer = 0
+         end_pointer = 0
+         len_start = len(start_ids)
+         len_end = len(end_ids)
+         # Maps each end to its start; a later start before the same end overwrites an earlier one.
+         couple_dict = {}
+         while start_pointer < len_start and end_pointer < len_end:
+             if with_prob:
+                 start_id = start_ids[start_pointer][0]
+                 end_id = end_ids[end_pointer][0]
+             else:
+                 start_id = start_ids[start_pointer]
+                 end_id = end_ids[end_pointer]
+
+             if start_id == end_id:
+                 couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
+                 start_pointer += 1
+                 end_pointer += 1
+                 continue
+             if start_id < end_id:
+                 couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
+                 start_pointer += 1
+                 continue
+             if start_id > end_id:
+                 end_pointer += 1
+                 continue
+         result = [(couple_dict[end], end) for end in couple_dict]
+         result = set(result)
+         return result
+
+     @classmethod
+     def get_id_and_prob(cls, span_set, offset_mapping: List[List[int]]):
+         """
+         Return text id and probability of predicted spans
+
+         Args:
+             span_set (set): set of predicted spans.
+             offset_mapping (List[List[int]]): list of pairs holding the start and end character
+                 index in the original text pair (prompt + text) for each token. It must be a plain
+                 list (e.g. `ndarray.tolist()`), since `list.index` is used below.
+         Returns:
+             sentence_id (list[tuple]): index of start and end char in original text.
+             prob (list[float]): probabilities of predicted spans.
+         """
+         # The first [0, 0] after [CLS] is the [SEP] closing the prompt; shift prompt
+         # offsets so that they become negative and can be told apart from text offsets.
+         prompt_end_token_id = offset_mapping[1:].index([0, 0])
+         bias = offset_mapping[prompt_end_token_id][1] + 1
+         for index in range(1, prompt_end_token_id + 1):
+             offset_mapping[index][0] -= bias
+             offset_mapping[index][1] -= bias
+
+         sentence_id = []
+         prob = []
+         for start, end in span_set:
+             prob.append(start[1] * end[1])
+             start_id = offset_mapping[start[0]][0]
+             end_id = offset_mapping[end[0]][1]
+             sentence_id.append((start_id, end_id))
+         return sentence_id, prob
+
+     @classmethod
+     def auto_splitter(cls, input_texts, max_text_len, split_sentence=False):
+         """
+         Split the raw texts automatically for model inference.
+         Args:
+             input_texts (List[str]): input raw texts.
+             max_text_len (int): cutting length.
+             split_sentence (bool): If True, sentence-level split will be performed.
+         Returns:
+             short_input_texts (List[str]): the short input texts for model inference.
+             input_mapping (dict): mapping between raw text and short input texts.
+         """
+         input_mapping = {}
+         short_input_texts = []
+         cnt_org = 0
+         cnt_short = 0
+         for text in input_texts:
+             if not split_sentence:
+                 sens = [text]
+             else:
+                 sens = Utils.cut_chinese_sent(text)
+             for sen in sens:
+                 lens = len(sen)
+                 if lens <= max_text_len:
+                     short_input_texts.append(sen)
+                     if cnt_org not in input_mapping.keys():
+                         input_mapping[cnt_org] = [cnt_short]
+                     else:
+                         input_mapping[cnt_org].append(cnt_short)
+                     cnt_short += 1
+                 else:
+                     # Cut an over-long sentence into fixed-size character windows.
+                     temp_text_list = [sen[i: i + max_text_len] for i in range(0, lens, max_text_len)]
+                     short_input_texts.extend(temp_text_list)
+                     short_idx = cnt_short
+                     cnt_short += math.ceil(lens / max_text_len)
+                     temp_text_id = [short_idx + i for i in range(cnt_short - short_idx)]
+                     if cnt_org not in input_mapping.keys():
+                         input_mapping[cnt_org] = temp_text_id
+                     else:
+                         input_mapping[cnt_org].extend(temp_text_id)
+             cnt_org += 1
+         return short_input_texts, input_mapping
+
+     @classmethod
+     def convert_ids_to_results(cls, examples, sentence_ids, probs):
+         """
+         Convert ids to raw text in a single stage.
+         """
+         results = []
+         for example, sentence_id, prob in zip(examples, sentence_ids, probs):
+             if len(sentence_id) == 0:
+                 results.append([])
+                 continue
+             result_list = []
+             text = example["text"]
+             prompt = example["prompt"]
+             for i in range(len(sentence_id)):
+                 start, end = sentence_id[i]
+                 if start < 0 and end >= 0:
+                     continue
+                 if end < 0:
+                     start += len(prompt) + 1
+                     end += len(prompt) + 1
+                     result = {"text": prompt[start:end], "probability": prob[i]}
+                     result_list.append(result)
+                 else:
+                     result = {"text": text[start:end], "start": start, "end": end, "probability": prob[i]}
+                     result_list.append(result)
+             results.append(result_list)
+         return results
+
+     @classmethod
+     def auto_joiner(cls, short_results, short_inputs, input_mapping):
+         concat_results = []
+         is_cls_task = False
+         for short_result in short_results:
+             if not short_result:
+                 continue
+             elif "start" not in short_result[0].keys() and "end" not in short_result[0].keys():
+                 is_cls_task = True
+                 break
+             else:
+                 break
+         for k, vs in input_mapping.items():
+             if is_cls_task:
+                 cls_options = {}
+                 for v in vs:
+                     if len(short_results[v]) == 0:
+                         continue
+                     if short_results[v][0]["text"] not in cls_options.keys():
+                         cls_options[short_results[v][0]["text"]] = [1, short_results[v][0]["probability"]]
+                     else:
+                         cls_options[short_results[v][0]["text"]][0] += 1
+                         cls_options[short_results[v][0]["text"]][1] += short_results[v][0]["probability"]
+                 if len(cls_options) != 0:
+                     cls_res, cls_info = max(cls_options.items(), key=lambda x: x[1])
+                     concat_results.append([{"text": cls_res, "probability": cls_info[1] / cls_info[0]}])
+                 else:
+                     concat_results.append([])
+             else:
+                 offset = 0
+                 single_results = []
+                 for v in vs:
+                     if v == 0:
+                         single_results = short_results[v]
+                         offset += len(short_inputs[v])
+                     else:
+                         for i in range(len(short_results[v])):
+                             if "start" not in short_results[v][i] or "end" not in short_results[v][i]:
+                                 continue
+                             short_results[v][i]["start"] += offset
+                             short_results[v][i]["end"] += offset
+                         offset += len(short_inputs[v])
+                         single_results.extend(short_results[v])
+                 concat_results.append(single_results)
+         return concat_results
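
Reviewer note on the decoding step: `_single_stage_predict` thresholds the start and end probabilities (`Utils.get_bool_ids_greater_than`) and then pairs them with a two-pointer sweep (`Utils.get_span`). The following is a minimal standalone sketch of that idea, assuming plain Python lists instead of NumPy arrays. The names `threshold_ids` and `pair_spans` are illustrative only and are not part of the uploaded file.

```python
from typing import Dict, List, Set, Tuple

Scored = Tuple[int, float]  # (token index, probability)


def threshold_ids(probs: List[float], limit: float = 0.5) -> List[Scored]:
    """Keep (index, prob) for every position whose probability exceeds `limit`."""
    return [(i, p) for i, p in enumerate(probs) if p > limit]


def pair_spans(starts: List[Scored], ends: List[Scored]) -> Set[Tuple[Scored, Scored]]:
    """Pair ends with starts using a single left-to-right two-pointer sweep."""
    starts = sorted(starts, key=lambda x: x[0])
    ends = sorted(ends, key=lambda x: x[0])
    s = e = 0
    end_to_start: Dict[Scored, Scored] = {}
    while s < len(starts) and e < len(ends):
        start_id, end_id = starts[s][0], ends[e][0]
        if start_id > end_id:
            # This end lies before the current start: move on to the next end.
            e += 1
            continue
        # start_id <= end_id: record the pair; a later start that still
        # precedes the same end overwrites this entry.
        end_to_start[ends[e]] = starts[s]
        s += 1
        if start_id == end_id:
            # Single-token span: this end is consumed as well.
            e += 1
    return {(start, end) for end, start in end_to_start.items()}


if __name__ == "__main__":
    start_probs = [0.1, 0.9, 0.2, 0.8, 0.1]
    end_probs = [0.0, 0.1, 0.95, 0.1, 0.7]
    spans = pair_spans(threshold_ids(start_probs), threshold_ids(end_probs))
    for start, end in sorted(spans):
        # Span score mirrors `get_id_and_prob`: product of start and end probabilities.
        print(start[0], end[0], start[1] * end[1])
```

Because each end keeps only the last start assigned to it, two candidate starts before a single end collapse to the later (shorter) span; that is a property of the sweep, not an extra filtering step.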