riccorl committed
Commit b5cf495
1 Parent(s): 52ea8eb

Upload model

Files changed (3)
  1. config.json +1 -0
  2. modeling_relik.py +65 -57
  3. pytorch_model.bin +2 -2
config.json CHANGED
@@ -1,4 +1,5 @@
 {
+  "_name_or_path": "models/hf_test/hf_test",
   "activation": "gelu",
   "add_entity_embedding": null,
   "additional_special_symbols": 101,
modeling_relik.py CHANGED
@@ -32,6 +32,7 @@ class RelikReaderSample:
             self._d[key] = value
         else:
             super().__setattr__(key, value)
+            self._d[key] = value
 
 
 activation2functions = {
@@ -321,20 +322,40 @@ class RelikReaderSpanModel(PreTrainedModel):
         # flattening end predictions
         # (flattening can happen only if the
         # end boundaries were not predicted using the gold labels)
-        if not self.training:
-            flattened_end_predictions = torch.clone(ned_start_predictions)
-            flattened_end_predictions[flattened_end_predictions > 0] = 0
-
-            batch_start_predictions = list()
-            for elem_idx in range(batch_size):
-                batch_start_predictions.append(
-                    torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist()
-                )
+        if not self.training and ned_end_logits is not None:
+            flattened_end_predictions = torch.zeros_like(ned_start_predictions)
+
+            row_indices, start_positions = torch.where(ned_start_predictions > 0)
+            ned_end_predictions[
+                ned_end_predictions < start_positions
+            ] = start_positions[ned_end_predictions < start_positions]
+
+            end_spans_repeated = (row_indices + 1) * seq_len + ned_end_predictions
+            cummax_values, _ = end_spans_repeated.cummax(dim=0)
+
+            end_spans_repeated = end_spans_repeated > torch.cat(
+                (end_spans_repeated[:1], cummax_values[:-1])
+            )
+            end_spans_repeated[0] = True
+
+            ned_start_predictions[
+                row_indices[~end_spans_repeated],
+                start_positions[~end_spans_repeated],
+            ] = 0
+
+            row_indices, start_positions, ned_end_predictions = (
+                row_indices[end_spans_repeated],
+                start_positions[end_spans_repeated],
+                ned_end_predictions[end_spans_repeated],
+            )
+
+            flattened_end_predictions[row_indices, ned_end_predictions] = 1
+
+            total_start_predictions, total_end_predictions = (
+                ned_start_predictions.sum(),
+                flattened_end_predictions.sum(),
+            )
 
-            # check that the total number of start predictions
-            # is equal to the end predictions
-            total_start_predictions = sum(map(len, batch_start_predictions))
-            total_end_predictions = len(ned_end_predictions)
             assert (
                 total_start_predictions == 0
                 or total_start_predictions == total_end_predictions
@@ -342,23 +363,9 @@ class RelikReaderSpanModel(PreTrainedModel):
                 f"Total number of start predictions = {total_start_predictions}. "
                 f"Total number of end predictions = {total_end_predictions}"
             )
-
-            curr_end_pred_num = 0
-            for elem_idx, bsp in enumerate(batch_start_predictions):
-                for sp in bsp:
-                    ep = ned_end_predictions[curr_end_pred_num].item()
-                    if ep < sp:
-                        ep = sp
-
-                    # if we already set this span throw it (no overlap)
-                    if flattened_end_predictions[elem_idx, ep] == 1:
-                        ned_start_predictions[elem_idx, sp] = 0
-                    else:
-                        flattened_end_predictions[elem_idx, ep] = 1
-
-                    curr_end_pred_num += 1
-
             ned_end_predictions = flattened_end_predictions
+        else:
+            ned_end_predictions = torch.zeros_like(ned_start_predictions)
 
         start_position, end_position = (
             (start_labels, end_labels)
@@ -461,7 +468,7 @@ class RelikReaderREModel(PreTrainedModel):
         self.transformer_model.resize_token_embeddings(
             self.transformer_model.config.vocab_size
             + config.additional_special_symbols
-            + config.additional_special_symbols_types
+            + config.additional_special_symbols_types,
         )
 
         # named entity detection layers
@@ -478,17 +485,21 @@ class RelikReaderREModel(PreTrainedModel):
         )
 
         if self.config.entity_type_loss and self.config.add_entity_embedding:
-            input_hidden_ents = 3 * self.transformer_model.config.hidden_size
+            input_hidden_ents = 3 * self.config.linears_hidden_size
         else:
-            input_hidden_ents = 2 * self.transformer_model.config.hidden_size
+            input_hidden_ents = 2 * self.config.linears_hidden_size
 
-        self.re_subject_projector = self._get_projection_layer(
-            config.activation, input_hidden=input_hidden_ents
+        self.re_projector = self._get_projection_layer(
+            config.activation,
+            input_hidden=2 * self.transformer_model.config.hidden_size,
+            hidden=input_hidden_ents,
+            last_hidden=2 * self.config.linears_hidden_size,
        )
-        self.re_object_projector = self._get_projection_layer(
-            config.activation, input_hidden=input_hidden_ents
+
+        self.re_relation_projector = self._get_projection_layer(
+            config.activation,
+            input_hidden=self.transformer_model.config.hidden_size,
        )
-        self.re_relation_projector = self._get_projection_layer(config.activation)
 
         if self.config.entity_type_loss or self.relation_disambiguation_loss:
             self.re_entities_projector = self._get_projection_layer(
@@ -516,6 +527,7 @@ class RelikReaderREModel(PreTrainedModel):
         self,
         activation: str,
         last_hidden: Optional[int] = None,
+        hidden: Optional[int] = None,
         input_hidden=None,
         layer_norm: bool = True,
     ) -> torch.nn.Sequential:
@@ -528,12 +540,12 @@ class RelikReaderREModel(PreTrainedModel):
                     if input_hidden is None
                     else input_hidden
                 ),
-                self.config.linears_hidden_size,
+                self.config.linears_hidden_size if hidden is None else hidden,
             ),
             activation2functions[activation],
             torch.nn.Dropout(0.1),
             torch.nn.Linear(
-                self.config.linears_hidden_size,
+                self.config.linears_hidden_size if hidden is None else hidden,
                 self.config.linears_hidden_size if last_hidden is None else last_hidden,
             ),
         ]
@@ -635,8 +647,13 @@ class RelikReaderREModel(PreTrainedModel):
         model_entity_features,
         special_symbols_features,
     ) -> torch.Tensor:
-        model_subject_features = self.re_subject_projector(model_entity_features)
-        model_object_features = self.re_object_projector(model_entity_features)
+        model_subject_object_features = self.re_projector(model_entity_features)
+        model_subject_features = model_subject_object_features[
+            :, :, : model_subject_object_features.shape[-1] // 2
+        ]
+        model_object_features = model_subject_object_features[
+            :, :, model_subject_object_features.shape[-1] // 2 :
+        ]
         special_symbols_start_representation = self.re_relation_projector(
             special_symbols_features
         )
@@ -720,13 +737,17 @@ class RelikReaderREModel(PreTrainedModel):
         end_labels: Optional[torch.Tensor] = None,
         disambiguation_labels: Optional[torch.Tensor] = None,
         relation_labels: Optional[torch.Tensor] = None,
-        relation_threshold: float = 0.5,
+        relation_threshold: float = None,
         is_validation: bool = False,
         is_prediction: bool = False,
         use_predefined_spans: bool = False,
         *args,
         **kwargs,
     ) -> Dict[str, Any]:
+        relation_threshold = (
+            self.config.threshold if relation_threshold is None else relation_threshold
+        )
+
         batch_size = input_ids.shape[0]
 
         model_features = self._get_model_features(
@@ -898,19 +919,7 @@ class RelikReaderREModel(PreTrainedModel):
                 re_probabilities = torch.softmax(re_logits, dim=-1)
                 # we set a thresshold instead of argmax in cause it needs to be tweaked
                 re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
-                # re_predictions = re_probabilities.argmax(dim=-1)
                 re_probabilities = re_probabilities[:, :, :, :, 1]
-                # re_logits, re_probabilities, re_predictions = (
-                #     torch.zeros(
-                #         [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
-                #     ).to(input_ids.device),
-                #     torch.zeros(
-                #         [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
-                #     ).to(input_ids.device),
-                #     torch.zeros(
-                #         [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
-                #     ).to(input_ids.device),
-                # )
 
             else:
                 (
@@ -981,10 +990,9 @@ class RelikReaderREModel(PreTrainedModel):
                 ) / 4
                 output_dict["ned_type_loss"] = ned_type_loss
             else:
-                output_dict["loss"] = ((1 / 4) * (ned_start_loss + ned_end_loss)) + (
-                    (1 / 2) * relation_loss
+                output_dict["loss"] = ((1 / 20) * (ned_start_loss + ned_end_loss)) + (
+                    (9 / 10) * relation_loss
                 )
-
             output_dict["ned_start_loss"] = ned_start_loss
             output_dict["ned_end_loss"] = ned_end_loss
             output_dict["re_loss"] = relation_loss
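Note on the modeling_relik.py changes above: the new span-flattening block in RelikReaderSpanModel.forward replaces the per-sample Python loop with a vectorized overlap filter. Start/end pairs come out of torch.where already ordered by batch row and start position, each end is clamped to be at least its start, ends are offset by (row + 1) * seq_len so rows never interfere, and a running cummax keeps only spans whose end moves past every previously kept end. Below is a minimal sketch of that trick on toy tensors; the helper name and example values are illustrative and not part of the repository.

import torch

def keep_non_overlapping(row_indices, starts, ends, seq_len):
    # Clamp ends so every candidate span satisfies end >= start,
    # mirroring the clamp applied to ned_end_predictions in the diff.
    ends = torch.maximum(ends, starts)
    # Offset each end by (row + 1) * seq_len so spans from different
    # batch rows can never be compared against one another.
    keys = (row_indices + 1) * seq_len + ends
    running_max, _ = keys.cummax(dim=0)
    # Keep a span only if its offset end strictly exceeds every end seen
    # so far; the first span is always kept (assumes at least one span,
    # as the diff does).
    keep = keys > torch.cat((keys[:1], running_max[:-1]))
    keep[0] = True
    return keep

# Toy example: one row with spans (1, 3), (2, 2), (4, 5); the middle span
# is contained in the first and gets dropped.
rows = torch.tensor([0, 0, 0])
starts = torch.tensor([1, 2, 4])
ends = torch.tensor([3, 2, 5])
print(keep_non_overlapping(rows, starts, ends, seq_len=10))  # tensor([ True, False,  True])

The relation-extraction head change follows the same idea of doing one pass instead of two: the former re_subject_projector and re_object_projector are folded into a single re_projector whose output is split in half along the last dimension, the first half serving as the subject view and the second half as the object view. A minimal sketch with illustrative sizes (a plain Linear stands in here for the multi-layer head built by _get_projection_layer):

import torch

hidden_size, linears_hidden_size = 768, 512  # illustrative sizes only
re_projector = torch.nn.Linear(2 * hidden_size, 2 * linears_hidden_size)

entity_features = torch.randn(4, 30, 2 * hidden_size)  # (batch, entities, 2 * hidden)
subject_object = re_projector(entity_features)          # (batch, entities, 2 * linears_hidden)
half = subject_object.shape[-1] // 2
subject_features = subject_object[..., :half]           # first half -> subject features
object_features = subject_object[..., half:]            # second half -> object features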
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:60ae577b1b5cb7bfff9776ccdbea4075e699976ecc9eb1e97afbf5a4d0933f1a
-size 747283514
+oid sha256:91ead9afa1a4b1d95a4d8b3997606b937616a928be232b9f13e01aa6cd766473
+size 747280506