MrPotato committed
Commit ad51607 · 1 Parent(s): 246ece9

commit files to HF hub
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,41 @@
+ {
+   "_name_or_path": "MrPotato/ref-seg-ger_large_tokenized",
+   "alpha": 0.5,
+   "architectures": [
+     "XLMRobertaForReferenceSegmentation"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "bos_token_id": 0,
+   "classifier_dropout": null,
+   "custom_pipelines": {
+     "ref-seg": {
+       "impl": "ref_seg.RefSegPipeline",
+       "pt": [
+         "AutoModelForTokenClassification"
+       ],
+       "tf": [
+         "TFAutoModelForTokenClassification"
+       ]
+     }
+   },
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 514,
+   "model_type": "xlm-roberta",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "num_labels_first": 29,
+   "num_labels_second": 2,
+   "pad_token_id": 1,
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.25.1",
+   "type_vocab_size": 1,
+   "use_cache": true,
+   "vocab_size": 250002
+ }
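config.json registers a custom "ref-seg" pipeline task (implemented in ref_seg.py below) on top of an XLM-RoBERTa-large encoder with two label heads (29 IOB field labels and 2 reference-boundary labels). A minimal loading sketch, assuming the repo id matches "_name_or_path", that the checkpoint resolves through AutoModelForTokenClassification as the "pt" entry specifies, and an illustrative input string (the custom XLMRobertaForReferenceSegmentation class itself is not part of this commit):

from transformers import pipeline

# Hedged sketch: repo id assumed from "_name_or_path"; adjust to the actual Hub repo.
# trust_remote_code=True lets transformers fetch ref_seg.RefSegPipeline from the repo.
ref_seg = pipeline(
    "ref-seg",
    model="MrPotato/ref-seg-ger_large_tokenized",
    trust_remote_code=True,
)

result = ref_seg("Aho, A.V., Sethi, R., Ullman, J.D.: Compilers. Addison-Wesley, Reading (1986)")
print(result["number_of_references"])  # reference count from RefSegPipeline.postprocess
print(result["references"])            # per-reference lists of grouped fields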
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bc7dbcc9cd8cad6d81cba90b6b3e510410adfb3c9a8ab28fbca81708bd63688c
+ size 2235624885
ref_seg.py ADDED
@@ -0,0 +1,292 @@
+ from transformers import AutoTokenizer, XLMRobertaForTokenClassification, Pipeline, AutoModelForTokenClassification, AutoModel, XLMRobertaTokenizerFast
+ from tokenizers.pre_tokenizers import Whitespace
+ from transformers.pipelines import PIPELINE_REGISTRY
+ from itertools import chain
+ from colorama import Fore, Back
+ from colorama import Style
+ import numpy as np
+ from transformers.models.xlm_roberta import XLMRobertaPreTrainedModel, XLMRobertaModel
+ from transformers.models.roberta import RobertaConfig
+ from transformers.modeling_outputs import TokenClassifierOutput
+ from transformers import PretrainedConfig
+ import torch
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ from typing import List, Optional, Tuple, Union
+
+ class RefSegPipeline(Pipeline):
+
+     labels = [
+         'publisher', 'source', 'url', 'other', 'author', 'editor', 'lpage',
+         'volume', 'year', 'issue', 'title', 'fpage', 'edition'
+     ]
+     iob_labels = list(chain.from_iterable([['B-' + x, 'I-' + x] for x in labels])) + ['O']
+     id2seg = {k: v for k, v in enumerate(iob_labels)}
+     id2ref = {k: v for k, v in enumerate(['B-ref', 'I-ref', ])}
+
+     def _sanitize_parameters(self, **kwargs):
+         if "id2seg" in kwargs:
+             self.id2seg = kwargs["id2seg"]
+         if "id2ref" in kwargs:
+             self.id2ref = kwargs["id2ref"]
+         return {}, {}, {}
+
+     def preprocess(self, sentence, offset_mapping=None):
+         model_inputs = self.tokenizer(
+             sentence,
+             return_offsets_mapping=True,
+             padding='max_length',
+             truncation=True,
+             max_length=512,
+             return_tensors="pt",
+             return_special_tokens_mask=True,
+             return_overflowing_tokens=True
+         )
+
+         if offset_mapping:
+             model_inputs["offset_mapping"] = offset_mapping
+
+         model_inputs["sentence"] = sentence
+
+         return model_inputs
+
+     def _forward(self, model_inputs):
+         special_tokens_mask = model_inputs.pop("special_tokens_mask")
+         offset_mapping = model_inputs.pop("offset_mapping", None)
+         sentence = model_inputs.pop("sentence")
+         overflow_mapping = model_inputs.pop("overflow_to_sample_mapping")
+         if self.framework == "tf":
+             logits = self.model(model_inputs.data)[0]
+         else:
+             logits = self.model(**model_inputs)[0]
+
+         return {
+             "logits": logits,
+             "special_tokens_mask": special_tokens_mask,
+             "offset_mapping": offset_mapping,
+             "overflow_mapping": overflow_mapping,
+             "sentence": sentence,
+             **model_inputs,
+         }
+
+     def postprocess(self, model_outputs):
+         # if ignore_labels is None:
+         ignore_labels = ["O"]
+         logits_seg = model_outputs["logits"][0].numpy()
+         logits_ref = model_outputs["logits"][1].numpy()
+         sentence = model_outputs["sentence"]
+         input_ids = model_outputs["input_ids"]
+         special_tokens_mask = model_outputs["special_tokens_mask"]
+         overflow_mapping = model_outputs["overflow_mapping"]
+
+         offset_mapping = model_outputs["offset_mapping"] if model_outputs["offset_mapping"] is not None else None
+
+         maxes_seg = np.max(logits_seg, axis=-1, keepdims=True)
+         shifted_exp_seg = np.exp(logits_seg - maxes_seg)
+         scores_seg = shifted_exp_seg / shifted_exp_seg.sum(axis=-1, keepdims=True)
+
+         maxes_ref = np.max(logits_ref, axis=-1, keepdims=True)
+         shifted_exp_ref = np.exp(logits_ref - maxes_ref)
+         scores_ref = shifted_exp_ref / shifted_exp_ref.sum(axis=-1, keepdims=True)
+
+         pre_entities = self.gather_pre_entities(
+             sentence, input_ids, scores_seg, scores_ref, offset_mapping, special_tokens_mask
+         )
+         grouped_entities = self.aggregate(pre_entities)
+
+         cleaned_groups = []
+         for group in grouped_entities:
+             entities = [
+                 entity
+                 for entity in group
+                 if entity.get("entity_group", None) not in ignore_labels
+             ]
+             cleaned_groups.append(entities)
+         return {
+             "number_of_references": len(cleaned_groups),
+             "references": cleaned_groups,
+         }
+
+     def gather_pre_entities(
+         self,
+         sentence: str,
+         input_ids: np.ndarray,
+         scores_seg: np.ndarray,
+         scores_ref: np.ndarray,
+         offset_mappings: Optional[List[Tuple[int, int]]],
+         special_tokens_masks: np.ndarray,
+     ) -> List[dict]:
+         """Fuse various numpy arrays into dicts with all the information needed for aggregation"""
+         pre_entities = []
+         for idx_list, (input_id, offset_mapping, special_tokens_mask, s_seg, s_ref) in enumerate(
+                 zip(input_ids, offset_mappings, special_tokens_masks, scores_seg, scores_ref)):
+             for idx, iid in enumerate(input_id):
+
+                 if special_tokens_mask[idx]:
+                     continue
+
+                 word = self.tokenizer.convert_ids_to_tokens(int(input_id[idx]))
+                 if offset_mapping is not None:
+                     start_ind, end_ind = offset_mapping[idx]
+                     if not isinstance(start_ind, int):
+                         if self.framework == "pt":
+                             start_ind = start_ind.item()
+                             end_ind = end_ind.item()
+                     word_ref = sentence[start_ind:end_ind]
+                     if getattr(self.tokenizer._tokenizer.model, "continuing_subword_prefix", None):
+                         is_subword = len(word) != len(word_ref)
+                     else:
+                         is_subword = len(word) == len(word_ref)
+
+                     if int(input_id[idx]) == self.tokenizer.unk_token_id:
+                         word = word_ref
+                         is_subword = False
+                 else:
+                     start_ind = None
+                     end_ind = None
+                     is_subword = False
+
+                 pre_entity = {
+                     "word": word,
+                     "scores_seg": s_seg[idx],
+                     "scores_ref": s_ref[idx],
+                     "start": start_ind,
+                     "end": end_ind,
+                     "index": idx,
+                     "is_subword": is_subword,
+                 }
+                 pre_entities.append(pre_entity)
+         return pre_entities
+
+     def aggregate(self, pre_entities: List[dict]) -> List[dict]:
+         entities = self.aggregate_words(pre_entities)
+
+         return self.group_entities(entities)
+
+     def aggregate_word(self, entities: List[dict]) -> dict:
+         word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])
+         scores_seg = entities[0]["scores_seg"]
+         idx_seg = scores_seg.argmax()
+         score_seg = scores_seg[idx_seg]
+         entity_seg = self.id2seg[idx_seg]
+
+         scores_ref = np.stack([entity["scores_ref"] for entity in entities])
+         indices_ref = scores_ref.argmax(axis=1)
+         idx_ref = 1 if all(indices_ref) else 0
+         # score_ref = 1
+         entity_ref = self.id2ref[idx_ref]
+
+         new_entity = {
+             "entity_seg": entity_seg,
+             "score_seg": score_seg,
+             "entity_ref": entity_ref,
+             # "score_ref": score_ref,
+             "word": word,
+             "start": entities[0]["start"],
+             "end": entities[-1]["end"],
+         }
+         return new_entity
+
+     def aggregate_words(self, entities: List[dict]) -> List[dict]:
+         """
+         Override tokens from a given word that disagree to force agreement on word boundaries.
+         Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with first strategy as microsoft|
+         company| B-ENT I-ENT
+         """
+         word_entities = []
+         word_group = None
+         for entity in entities:
+             if word_group is None:
+                 word_group = [entity]
+             elif entity["is_subword"]:
+                 word_group.append(entity)
+             else:
+                 word_entities.append(self.aggregate_word(word_group))
+                 word_group = [entity]
+         word_entities.append(self.aggregate_word(word_group))
+         return word_entities
+
+     def group_entities(self, entities: List[dict]) -> List[dict]:
+         """
+         Find and group together the adjacent tokens with the same entity predicted.
+         Args:
+             entities (`dict`): The entities predicted by the pipeline.
+         """
+         entity_chunk = []
+         entity_chunk_disagg = []
+
+         for entity in entities:
+             if not entity_chunk_disagg:
+                 entity_chunk_disagg.append(entity)
+                 continue
+
+             bi_ref, tag_ref = self.get_tag(entity["entity_ref"])
+             last_bi_ref, last_tag_ref = self.get_tag(entity_chunk_disagg[-1]["entity_ref"])
+
+             if tag_ref == last_tag_ref and bi_ref != "B":
+                 entity_chunk_disagg.append(entity)
+             else:
+                 entity_chunk.append(entity_chunk_disagg)
+                 entity_chunk_disagg = [entity]
+
+         if entity_chunk_disagg:
+             entity_chunk.append(entity_chunk_disagg)
+
+         entity_chunks_all = []
+
+         for chunk in entity_chunk:
+
+             entity_groups = []
+             entity_group_disagg = []
+
+             for entity in chunk:
+                 if not entity_group_disagg:
+                     entity_group_disagg.append(entity)
+                     continue
+
+                 bi_seg, tag_seg = self.get_tag(entity["entity_seg"])
+                 last_bi_seg, last_tag_seg = self.get_tag(entity_group_disagg[-1]["entity_seg"])
+
+                 if tag_seg == last_tag_seg and bi_seg != "B":
+                     entity_group_disagg.append(entity)
+                 else:
+                     entity_groups.append(self.group_sub_entities(entity_group_disagg))
+                     entity_group_disagg = [entity]
+
+             if entity_group_disagg:
+                 entity_groups.append(self.group_sub_entities(entity_group_disagg))
+
+             entity_chunks_all.append(entity_groups)
+
+         return entity_chunks_all
+
+     def group_sub_entities(self, entities: List[dict]) -> dict:
+         """
+         Group together the adjacent tokens with the same entity predicted.
+         Args:
+             entities (`dict`): The entities predicted by the pipeline.
+         """
+         entity = entities[0]["entity_seg"].split("-")[-1]
+         scores = np.nanmean([entity["score_seg"] for entity in entities])
+         tokens = [entity["word"] for entity in entities]
+
+         entity_group = {
+             "entity_group": entity,
+             "score": np.mean(scores),
+             "word": " ".join(tokens),
+             "start": entities[0]["start"],
+             "end": entities[-1]["end"],
+         }
+         return entity_group
+
+     def get_tag(self, entity_name: str) -> Tuple[str, str]:
+         if entity_name.startswith("B-"):
+             bi = "B"
+             tag = entity_name[2:]
+         elif entity_name.startswith("I-"):
+             bi = "I"
+             tag = entity_name[2:]
+         else:
+             bi = "I"
+             tag = entity_name
+         return bi, tag
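RefSegPipeline groups on two levels: aggregate_word fuses sub-word tokens into words, then group_entities first splits the token stream into references using the B-ref/I-ref head and then groups adjacent field labels inside each reference. The PIPELINE_REGISTRY import also allows registering the task locally rather than going through the custom_pipelines entry in config.json. A hedged sketch of that route, where the repo id and example input are assumptions and loading via AutoModelForTokenClassification follows the "pt" entry in config.json:

from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from ref_seg import RefSegPipeline

# Register the custom task locally (mirrors the custom_pipelines entry in config.json).
PIPELINE_REGISTRY.register_pipeline(
    "ref-seg",
    pipeline_class=RefSegPipeline,
    pt_model=AutoModelForTokenClassification,
)

# Assumed repo id; already-loaded model/tokenizer objects can be passed instead.
repo_id = "MrPotato/ref-seg-ger_large_tokenized"
ref_seg = pipeline(
    "ref-seg",
    model=AutoModelForTokenClassification.from_pretrained(repo_id),
    tokenizer=AutoTokenizer.from_pretrained(repo_id),
)

out = ref_seg("Miller, T.: An Example Title. Example Press, Berlin (2020).")
for reference in out["references"]:
    print([(field["entity_group"], field["word"]) for field in reference])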
special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "bos_token": "<s>",
+   "cls_token": "<s>",
+   "eos_token": "</s>",
+   "mask_token": {
+     "content": "<mask>",
+     "lstrip": true,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<pad>",
+   "sep_token": "</s>",
+   "unk_token": "<unk>"
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:62c24cdc13d4c9952d63718d6c9fa4c287974249e16b7ade6d5a85e7bbb75626
+ size 17082660
tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "bos_token": "<s>",
+   "cls_token": "<s>",
+   "eos_token": "</s>",
+   "mask_token": {
+     "__type": "AddedToken",
+     "content": "<mask>",
+     "lstrip": true,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "model_max_length": 512,
+   "name_or_path": "xlm-roberta-large",
+   "pad_token": "<pad>",
+   "sep_token": "</s>",
+   "special_tokens_map_file": null,
+   "tokenizer_class": "XLMRobertaTokenizer",
+   "unk_token": "<unk>"
+ }
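The tokenizer files back the settings that RefSegPipeline.preprocess relies on: a fast tokenizer (tokenizer.json) for return_offsets_mapping, and model_max_length of 512 matching max_length=512. A minimal check, assuming the files above are loaded from this repo (repo id and example input are assumptions):

from transformers import AutoTokenizer

# Assumed repo id; AutoTokenizer picks the fast XLM-RoBERTa tokenizer backed by tokenizer.json.
tokenizer = AutoTokenizer.from_pretrained("MrPotato/ref-seg-ger_large_tokenized")

print(type(tokenizer).__name__)    # expected: XLMRobertaTokenizerFast
print(tokenizer.model_max_length)  # 512, matching max_length in RefSegPipeline.preprocess

# Offsets, special-token mask, and overflow chunks, as requested by preprocess.
enc = tokenizer(
    "Miller, T.: An Example Title. Example Press, Berlin (2020).",
    return_offsets_mapping=True,
    return_special_tokens_mask=True,
    truncation=True,
    max_length=512,
    return_overflowing_tokens=True,
)
print(len(enc["input_ids"]))  # number of overflow chunks (1 for a short input)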