Upload model
Browse files- config.json +1 -0
- modeling_relik.py +65 -57
- pytorch_model.bin +2 -2
config.json
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
{
|
|
|
2 |
"activation": "gelu",
|
3 |
"add_entity_embedding": null,
|
4 |
"additional_special_symbols": 101,
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "models/hf_test/hf_test",
|
3 |
"activation": "gelu",
|
4 |
"add_entity_embedding": null,
|
5 |
"additional_special_symbols": 101,
|
modeling_relik.py
CHANGED
@@ -32,6 +32,7 @@ class RelikReaderSample:
|
|
32 |
self._d[key] = value
|
33 |
else:
|
34 |
super().__setattr__(key, value)
|
|
|
35 |
|
36 |
|
37 |
activation2functions = {
|
@@ -321,20 +322,40 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
321 |
# flattening end predictions
|
322 |
# (flattening can happen only if the
|
323 |
# end boundaries were not predicted using the gold labels)
|
324 |
-
if not self.training:
|
325 |
-
flattened_end_predictions = torch.
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
333 |
|
334 |
-
# check that the total number of start predictions
|
335 |
-
# is equal to the end predictions
|
336 |
-
total_start_predictions = sum(map(len, batch_start_predictions))
|
337 |
-
total_end_predictions = len(ned_end_predictions)
|
338 |
assert (
|
339 |
total_start_predictions == 0
|
340 |
or total_start_predictions == total_end_predictions
|
@@ -342,23 +363,9 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
342 |
f"Total number of start predictions = {total_start_predictions}. "
|
343 |
f"Total number of end predictions = {total_end_predictions}"
|
344 |
)
|
345 |
-
|
346 |
-
curr_end_pred_num = 0
|
347 |
-
for elem_idx, bsp in enumerate(batch_start_predictions):
|
348 |
-
for sp in bsp:
|
349 |
-
ep = ned_end_predictions[curr_end_pred_num].item()
|
350 |
-
if ep < sp:
|
351 |
-
ep = sp
|
352 |
-
|
353 |
-
# if we already set this span throw it (no overlap)
|
354 |
-
if flattened_end_predictions[elem_idx, ep] == 1:
|
355 |
-
ned_start_predictions[elem_idx, sp] = 0
|
356 |
-
else:
|
357 |
-
flattened_end_predictions[elem_idx, ep] = 1
|
358 |
-
|
359 |
-
curr_end_pred_num += 1
|
360 |
-
|
361 |
ned_end_predictions = flattened_end_predictions
|
|
|
|
|
362 |
|
363 |
start_position, end_position = (
|
364 |
(start_labels, end_labels)
|
@@ -461,7 +468,7 @@ class RelikReaderREModel(PreTrainedModel):
|
|
461 |
self.transformer_model.resize_token_embeddings(
|
462 |
self.transformer_model.config.vocab_size
|
463 |
+ config.additional_special_symbols
|
464 |
-
+ config.additional_special_symbols_types
|
465 |
)
|
466 |
|
467 |
# named entity detection layers
|
@@ -478,17 +485,21 @@ class RelikReaderREModel(PreTrainedModel):
|
|
478 |
)
|
479 |
|
480 |
if self.config.entity_type_loss and self.config.add_entity_embedding:
|
481 |
-
input_hidden_ents = 3 * self.
|
482 |
else:
|
483 |
-
input_hidden_ents = 2 * self.
|
484 |
|
485 |
-
self.
|
486 |
-
config.activation,
|
|
|
|
|
|
|
487 |
)
|
488 |
-
|
489 |
-
|
|
|
|
|
490 |
)
|
491 |
-
self.re_relation_projector = self._get_projection_layer(config.activation)
|
492 |
|
493 |
if self.config.entity_type_loss or self.relation_disambiguation_loss:
|
494 |
self.re_entities_projector = self._get_projection_layer(
|
@@ -516,6 +527,7 @@ class RelikReaderREModel(PreTrainedModel):
|
|
516 |
self,
|
517 |
activation: str,
|
518 |
last_hidden: Optional[int] = None,
|
|
|
519 |
input_hidden=None,
|
520 |
layer_norm: bool = True,
|
521 |
) -> torch.nn.Sequential:
|
@@ -528,12 +540,12 @@ class RelikReaderREModel(PreTrainedModel):
|
|
528 |
if input_hidden is None
|
529 |
else input_hidden
|
530 |
),
|
531 |
-
self.config.linears_hidden_size,
|
532 |
),
|
533 |
activation2functions[activation],
|
534 |
torch.nn.Dropout(0.1),
|
535 |
torch.nn.Linear(
|
536 |
-
self.config.linears_hidden_size,
|
537 |
self.config.linears_hidden_size if last_hidden is None else last_hidden,
|
538 |
),
|
539 |
]
|
@@ -635,8 +647,13 @@ class RelikReaderREModel(PreTrainedModel):
|
|
635 |
model_entity_features,
|
636 |
special_symbols_features,
|
637 |
) -> torch.Tensor:
|
638 |
-
|
639 |
-
|
|
|
|
|
|
|
|
|
|
|
640 |
special_symbols_start_representation = self.re_relation_projector(
|
641 |
special_symbols_features
|
642 |
)
|
@@ -720,13 +737,17 @@ class RelikReaderREModel(PreTrainedModel):
|
|
720 |
end_labels: Optional[torch.Tensor] = None,
|
721 |
disambiguation_labels: Optional[torch.Tensor] = None,
|
722 |
relation_labels: Optional[torch.Tensor] = None,
|
723 |
-
relation_threshold: float =
|
724 |
is_validation: bool = False,
|
725 |
is_prediction: bool = False,
|
726 |
use_predefined_spans: bool = False,
|
727 |
*args,
|
728 |
**kwargs,
|
729 |
) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
730 |
batch_size = input_ids.shape[0]
|
731 |
|
732 |
model_features = self._get_model_features(
|
@@ -898,19 +919,7 @@ class RelikReaderREModel(PreTrainedModel):
|
|
898 |
re_probabilities = torch.softmax(re_logits, dim=-1)
|
899 |
# we set a threshold instead of argmax in case it needs to be tweaked
|
900 |
re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
|
901 |
-
# re_predictions = re_probabilities.argmax(dim=-1)
|
902 |
re_probabilities = re_probabilities[:, :, :, :, 1]
|
903 |
-
# re_logits, re_probabilities, re_predictions = (
|
904 |
-
# torch.zeros(
|
905 |
-
# [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
|
906 |
-
# ).to(input_ids.device),
|
907 |
-
# torch.zeros(
|
908 |
-
# [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
|
909 |
-
# ).to(input_ids.device),
|
910 |
-
# torch.zeros(
|
911 |
-
# [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
|
912 |
-
# ).to(input_ids.device),
|
913 |
-
# )
|
914 |
|
915 |
else:
|
916 |
(
|
@@ -981,10 +990,9 @@ class RelikReaderREModel(PreTrainedModel):
|
|
981 |
) / 4
|
982 |
output_dict["ned_type_loss"] = ned_type_loss
|
983 |
else:
|
984 |
-
output_dict["loss"] = ((1 /
|
985 |
-
(
|
986 |
)
|
987 |
-
|
988 |
output_dict["ned_start_loss"] = ned_start_loss
|
989 |
output_dict["ned_end_loss"] = ned_end_loss
|
990 |
output_dict["re_loss"] = relation_loss
|
|
|
32 |
self._d[key] = value
|
33 |
else:
|
34 |
super().__setattr__(key, value)
|
35 |
+
self._d[key] = value
|
36 |
|
37 |
|
38 |
activation2functions = {
|
|
|
322 |
# flattening end predictions
|
323 |
# (flattening can happen only if the
|
324 |
# end boundaries were not predicted using the gold labels)
|
325 |
+
if not self.training and ned_end_logits is not None:
|
326 |
+
flattened_end_predictions = torch.zeros_like(ned_start_predictions)
|
327 |
+
|
328 |
+
row_indices, start_positions = torch.where(ned_start_predictions > 0)
|
329 |
+
ned_end_predictions[
|
330 |
+
ned_end_predictions < start_positions
|
331 |
+
] = start_positions[ned_end_predictions < start_positions]
|
332 |
+
|
333 |
+
end_spans_repeated = (row_indices + 1) * seq_len + ned_end_predictions
|
334 |
+
cummax_values, _ = end_spans_repeated.cummax(dim=0)
|
335 |
+
|
336 |
+
end_spans_repeated = end_spans_repeated > torch.cat(
|
337 |
+
(end_spans_repeated[:1], cummax_values[:-1])
|
338 |
+
)
|
339 |
+
end_spans_repeated[0] = True
|
340 |
+
|
341 |
+
ned_start_predictions[
|
342 |
+
row_indices[~end_spans_repeated],
|
343 |
+
start_positions[~end_spans_repeated],
|
344 |
+
] = 0
|
345 |
+
|
346 |
+
row_indices, start_positions, ned_end_predictions = (
|
347 |
+
row_indices[end_spans_repeated],
|
348 |
+
start_positions[end_spans_repeated],
|
349 |
+
ned_end_predictions[end_spans_repeated],
|
350 |
+
)
|
351 |
+
|
352 |
+
flattened_end_predictions[row_indices, ned_end_predictions] = 1
|
353 |
+
|
354 |
+
total_start_predictions, total_end_predictions = (
|
355 |
+
ned_start_predictions.sum(),
|
356 |
+
flattened_end_predictions.sum(),
|
357 |
+
)
|
358 |
|
|
|
|
|
|
|
|
|
359 |
assert (
|
360 |
total_start_predictions == 0
|
361 |
or total_start_predictions == total_end_predictions
|
|
|
363 |
f"Total number of start predictions = {total_start_predictions}. "
|
364 |
f"Total number of end predictions = {total_end_predictions}"
|
365 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
ned_end_predictions = flattened_end_predictions
|
367 |
+
else:
|
368 |
+
ned_end_predictions = torch.zeros_like(ned_start_predictions)
|
369 |
|
370 |
start_position, end_position = (
|
371 |
(start_labels, end_labels)
|
|
|
468 |
self.transformer_model.resize_token_embeddings(
|
469 |
self.transformer_model.config.vocab_size
|
470 |
+ config.additional_special_symbols
|
471 |
+
+ config.additional_special_symbols_types,
|
472 |
)
|
473 |
|
474 |
# named entity detection layers
|
|
|
485 |
)
|
486 |
|
487 |
if self.config.entity_type_loss and self.config.add_entity_embedding:
|
488 |
+
input_hidden_ents = 3 * self.config.linears_hidden_size
|
489 |
else:
|
490 |
+
input_hidden_ents = 2 * self.config.linears_hidden_size
|
491 |
|
492 |
+
self.re_projector = self._get_projection_layer(
|
493 |
+
config.activation,
|
494 |
+
input_hidden=2 * self.transformer_model.config.hidden_size,
|
495 |
+
hidden=input_hidden_ents,
|
496 |
+
last_hidden=2 * self.config.linears_hidden_size,
|
497 |
)
|
498 |
+
|
499 |
+
self.re_relation_projector = self._get_projection_layer(
|
500 |
+
config.activation,
|
501 |
+
input_hidden=self.transformer_model.config.hidden_size,
|
502 |
)
|
|
|
503 |
|
504 |
if self.config.entity_type_loss or self.relation_disambiguation_loss:
|
505 |
self.re_entities_projector = self._get_projection_layer(
|
|
|
527 |
self,
|
528 |
activation: str,
|
529 |
last_hidden: Optional[int] = None,
|
530 |
+
hidden: Optional[int] = None,
|
531 |
input_hidden=None,
|
532 |
layer_norm: bool = True,
|
533 |
) -> torch.nn.Sequential:
|
|
|
540 |
if input_hidden is None
|
541 |
else input_hidden
|
542 |
),
|
543 |
+
self.config.linears_hidden_size if hidden is None else hidden,
|
544 |
),
|
545 |
activation2functions[activation],
|
546 |
torch.nn.Dropout(0.1),
|
547 |
torch.nn.Linear(
|
548 |
+
self.config.linears_hidden_size if hidden is None else hidden,
|
549 |
self.config.linears_hidden_size if last_hidden is None else last_hidden,
|
550 |
),
|
551 |
]
|
|
|
647 |
model_entity_features,
|
648 |
special_symbols_features,
|
649 |
) -> torch.Tensor:
|
650 |
+
model_subject_object_features = self.re_projector(model_entity_features)
|
651 |
+
model_subject_features = model_subject_object_features[
|
652 |
+
:, :, : model_subject_object_features.shape[-1] // 2
|
653 |
+
]
|
654 |
+
model_object_features = model_subject_object_features[
|
655 |
+
:, :, model_subject_object_features.shape[-1] // 2 :
|
656 |
+
]
|
657 |
special_symbols_start_representation = self.re_relation_projector(
|
658 |
special_symbols_features
|
659 |
)
|
|
|
737 |
end_labels: Optional[torch.Tensor] = None,
|
738 |
disambiguation_labels: Optional[torch.Tensor] = None,
|
739 |
relation_labels: Optional[torch.Tensor] = None,
|
740 |
+
relation_threshold: float = None,
|
741 |
is_validation: bool = False,
|
742 |
is_prediction: bool = False,
|
743 |
use_predefined_spans: bool = False,
|
744 |
*args,
|
745 |
**kwargs,
|
746 |
) -> Dict[str, Any]:
|
747 |
+
relation_threshold = (
|
748 |
+
self.config.threshold if relation_threshold is None else relation_threshold
|
749 |
+
)
|
750 |
+
|
751 |
batch_size = input_ids.shape[0]
|
752 |
|
753 |
model_features = self._get_model_features(
|
|
|
919 |
re_probabilities = torch.softmax(re_logits, dim=-1)
|
920 |
# we set a threshold instead of argmax in case it needs to be tweaked
|
921 |
re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
|
|
|
922 |
re_probabilities = re_probabilities[:, :, :, :, 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
923 |
|
924 |
else:
|
925 |
(
|
|
|
990 |
) / 4
|
991 |
output_dict["ned_type_loss"] = ned_type_loss
|
992 |
else:
|
993 |
+
output_dict["loss"] = ((1 / 20) * (ned_start_loss + ned_end_loss)) + (
|
994 |
+
(9 / 10) * relation_loss
|
995 |
)
|
|
|
996 |
output_dict["ned_start_loss"] = ned_start_loss
|
997 |
output_dict["ned_end_loss"] = ned_end_loss
|
998 |
output_dict["re_loss"] = relation_loss
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:91ead9afa1a4b1d95a4d8b3997606b937616a928be232b9f13e01aa6cd766473
|
3 |
+
size 747280506
|