sumit
commited on
Commit
•
26aaaca
1
Parent(s):
b20f15c
add start,end to entities
Browse files- BertForJointParsing.py +23 -14
- BertForPrefixMarking.py +1 -2
BertForJointParsing.py
CHANGED
@@ -200,8 +200,9 @@ class BertForJointParsing(BertPreTrainedModel):
|
|
200 |
if self.prefix is not None:
|
201 |
inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
|
202 |
else:
|
203 |
-
inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_tensors='pt')
|
204 |
-
|
|
|
205 |
# Copy the tensors to the right device, and parse!
|
206 |
inputs = {k:v.to(self.device) for k,v in inputs.items()}
|
207 |
output = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_syntax_mst)
|
@@ -230,7 +231,7 @@ class BertForJointParsing(BertPreTrainedModel):
|
|
230 |
|
231 |
# NER logits each sentence gets a list(tuple(word, ner))
|
232 |
if output.ner_logits is not None:
|
233 |
-
for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label)):
|
234 |
if per_token_ner:
|
235 |
merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
|
236 |
final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
|
@@ -247,17 +248,18 @@ class BertForJointParsing(BertPreTrainedModel):
|
|
247 |
def aggregate_ner_tokens(predictions):
|
248 |
entities = []
|
249 |
prev = None
|
250 |
-
for word,pred in predictions:
|
251 |
# O does nothing
|
252 |
if pred == 'O': prev = None
|
253 |
# B- || I-entity != prev (different entity or none)
|
254 |
elif pred.startswith('B-') or pred[2:] != prev:
|
255 |
prev = pred[2:]
|
256 |
-
entities.append(
|
257 |
-
else:
|
258 |
-
|
259 |
-
|
260 |
-
|
|
|
261 |
|
262 |
def merge_token_list(src, update, key):
|
263 |
for token_src, token_update in zip(src, update):
|
@@ -272,9 +274,9 @@ def combine_token_wordpieces(input_ids: torch.Tensor, tokenizer: BertTokenizerFa
|
|
272 |
else: ret.append(token)
|
273 |
return ret
|
274 |
|
275 |
-
def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
|
276 |
-
input_ids = inputs['input_ids']
|
277 |
-
|
278 |
predictions = torch.argmax(logits, dim=-1)
|
279 |
batch_ret = []
|
280 |
for batch_idx in range(len(sentences)):
|
@@ -286,11 +288,18 @@ def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], toke
|
|
286 |
if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]: continue
|
287 |
|
288 |
token = tokenizer._convert_id_to_token(token_id)
|
|
|
|
|
|
|
289 |
# wordpieces should just be appended to the previous word
|
|
|
|
|
290 |
if token.startswith('##'):
|
291 |
-
ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
|
292 |
continue
|
293 |
-
|
|
|
|
|
294 |
return batch_ret
|
295 |
|
296 |
def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
|
|
|
200 |
if self.prefix is not None:
|
201 |
inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
|
202 |
else:
|
203 |
+
inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')
|
204 |
+
|
205 |
+
offset_mapping = inputs.pop('offset_mapping')
|
206 |
# Copy the tensors to the right device, and parse!
|
207 |
inputs = {k:v.to(self.device) for k,v in inputs.items()}
|
208 |
output = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_syntax_mst)
|
|
|
231 |
|
232 |
# NER logits each sentence gets a list(tuple(word, ner))
|
233 |
if output.ner_logits is not None:
|
234 |
+
for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label, offset_mapping)):
|
235 |
if per_token_ner:
|
236 |
merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
|
237 |
final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
|
|
|
248 |
def aggregate_ner_tokens(predictions):
|
249 |
entities = []
|
250 |
prev = None
|
251 |
+
for word, pred, start, end in predictions:
|
252 |
# O does nothing
|
253 |
if pred == 'O': prev = None
|
254 |
# B- || I-entity != prev (different entity or none)
|
255 |
elif pred.startswith('B-') or pred[2:] != prev:
|
256 |
prev = pred[2:]
|
257 |
+
entities.append([[word], prev, start, end])
|
258 |
+
else:
|
259 |
+
entities[-1][0].append(word)
|
260 |
+
entities[-1][3] = end
|
261 |
+
|
262 |
+
return [dict(phrase=' '.join(words), label=label, start=start, end=end) for words, label, start, end in entities]
|
263 |
|
264 |
def merge_token_list(src, update, key):
|
265 |
for token_src, token_update in zip(src, update):
|
|
|
274 |
else: ret.append(token)
|
275 |
return ret
|
276 |
|
277 |
+
def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str], offset_mapping):
|
278 |
+
input_ids = inputs['input_ids']
|
279 |
+
|
280 |
predictions = torch.argmax(logits, dim=-1)
|
281 |
batch_ret = []
|
282 |
for batch_idx in range(len(sentences)):
|
|
|
288 |
if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]: continue
|
289 |
|
290 |
token = tokenizer._convert_id_to_token(token_id)
|
291 |
+
|
292 |
+
# get the offsets for this token
|
293 |
+
start_pos, end_pos = offset_mapping[batch_idx, tok_idx]
|
294 |
# wordpieces should just be appended to the previous word
|
295 |
+
# we modify the last token in ret
|
296 |
+
# by discarding the original end position and replacing it with the new token's end position
|
297 |
if token.startswith('##'):
|
298 |
+
ret[-1] = (ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos.item())
|
299 |
continue
|
300 |
+
# for each token, we append a tuple containing: token, label, start position, end position
|
301 |
+
ret.append((token, id2label[predictions[batch_idx, tok_idx].item()], start_pos.item(), end_pos.item()))
|
302 |
+
|
303 |
return batch_ret
|
304 |
|
305 |
def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
|
BertForPrefixMarking.py
CHANGED
@@ -184,8 +184,7 @@ def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenize
|
|
184 |
return ret
|
185 |
|
186 |
def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, sentences: List[str], padding='longest', truncation=True):
|
187 |
-
inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_tensors='pt')
|
188 |
-
|
189 |
# create our prefix_id_options array which will be like the input ids shape but with an addtional
|
190 |
# dimension containing for each prefix whether it can be for that word
|
191 |
prefix_id_options = torch.full(inputs['input_ids'].shape + (TOTAL_POSSIBLE_PREFIX_CLASSES,), TOTAL_POSSIBLE_PREFIX_CLASSES, dtype=torch.long)
|
|
|
184 |
return ret
|
185 |
|
186 |
def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, sentences: List[str], padding='longest', truncation=True):
|
187 |
+
inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')
|
|
|
188 |
# create our prefix_id_options array which will be like the input ids shape but with an addtional
|
189 |
# dimension containing for each prefix whether it can be for that word
|
190 |
prefix_id_options = torch.full(inputs['input_ids'].shape + (TOTAL_POSSIBLE_PREFIX_CLASSES,), TOTAL_POSSIBLE_PREFIX_CLASSES, dtype=torch.long)
|