ragavsachdeva committed on
Commit
841cdcf
·
verified ·
1 Parent(s): 5deac87

Update modelling_magiv2.py

Browse files
Files changed (1) hide show
  1. modelling_magiv2.py +149 -2
modelling_magiv2.py CHANGED
@@ -2,7 +2,6 @@ from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel
2
  from transformers.models.conditional_detr.modeling_conditional_detr import (
3
  ConditionalDetrMLPPredictionHead,
4
  ConditionalDetrModelOutput,
5
- ConditionalDetrHungarianMatcher,
6
  inverse_sigmoid,
7
  )
8
  from .configuration_magiv2 import Magiv2Config
@@ -17,6 +16,7 @@ from .utils import UnionFind, sort_panels, sort_text_boxes_in_reading_order
17
  import pulp
18
  import scipy
19
  import numpy as np
 
20
 
21
  class Magiv2Model(PreTrainedModel):
22
  config_class = Magiv2Config
@@ -611,4 +611,151 @@ class Magiv2Model(PreTrainedModel):
611
  if apply_sigmoid:
612
  text_tail_affinities = text_tail_affinities.sigmoid()
613
  affinity_matrices.append(text_tail_affinities)
614
- return affinity_matrices
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from transformers.models.conditional_detr.modeling_conditional_detr import (
3
  ConditionalDetrMLPPredictionHead,
4
  ConditionalDetrModelOutput,
 
5
  inverse_sigmoid,
6
  )
7
  from .configuration_magiv2 import Magiv2Config
 
16
  import pulp
17
  import scipy
18
  import numpy as np
19
+ from scipy.optimize import linear_sum_assignment
20
 
21
  class Magiv2Model(PreTrainedModel):
22
  config_class = Magiv2Config
 
611
  if apply_sigmoid:
612
  text_tail_affinities = text_tail_affinities.sigmoid()
613
  affinity_matrices.append(text_tail_affinities)
614
+ return affinity_matrices
615
+
616
+ # Copied from transformers.models.detr.modeling_detr._upcast
617
+ def _upcast(t):
618
+ # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
619
+ if t.is_floating_point():
620
+ return t if t.dtype in (torch.float32, torch.float64) else t.float()
621
+ else:
622
+ return t if t.dtype in (torch.int32, torch.int64) else t.int()
623
+
624
+
625
# Copied from transformers.models.detr.modeling_detr.box_area
def box_area(boxes):
    """Compute the area of each axis-aligned bounding box.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes in (x1, y1, x2, y2) corner format with `0 <= x1 < x2`
            and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: one area value per input box.
    """
    # Upcast first so the width * height product below cannot overflow
    # narrow dtypes.
    boxes = _upcast(boxes)
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights
640
+
641
+
642
# Copied from transformers.models.detr.modeling_detr.box_iou
def box_iou(boxes1, boxes2):
    """Pairwise intersection-over-union between two sets of corner-format boxes.

    Args:
        boxes1: tensor of shape `(N, 4)` in (x1, y1, x2, y2) format.
        boxes2: tensor of shape `(M, 4)` in (x1, y1, x2, y2) format.

    Returns:
        Tuple `(iou, union)` of tensors, each of shape `(N, M)`.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    # Corners of the intersection rectangle for every (boxes1, boxes2) pair.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])      # [N, M, 2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N, M, 2]

    # Clamp at zero: non-overlapping pairs contribute no intersection area.
    extents = (bottom_right - top_left).clamp(min=0)              # [N, M, 2]
    intersection = extents[..., 0] * extents[..., 1]              # [N, M]

    union = area1[:, None] + area2 - intersection
    return intersection / union, union
657
+
658
+
659
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
def generalized_box_iou(boxes1, boxes2):
    """Pairwise Generalized IoU (https://giou.stanford.edu/) between two box sets.

    Both inputs must be in [x0, y0, x1, y1] (corner) format.

    Returns:
        `torch.FloatTensor`: an [N, M] matrix with N = len(boxes1) and
        M = len(boxes2).
    """
    # Degenerate (inverted) boxes would yield inf / nan downstream, so
    # validate the corner ordering before doing any arithmetic.
    if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
        raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
    if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
        raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")

    iou, union = box_iou(boxes1, boxes2)

    # Smallest axis-aligned box enclosing each pair.
    enclosing_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclosing_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    enclosing_wh = (enclosing_br - enclosing_tl).clamp(min=0)  # [N, M, 2]
    enclosing_area = enclosing_wh[..., 0] * enclosing_wh[..., 1]

    # GIoU = IoU - (enclosing area not covered by the union) / enclosing area.
    return iou - (enclosing_area - union) / enclosing_area
682
+
683
+
684
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->ConditionalDetr
class ConditionalDetrHungarianMatcher(nn.Module):
    """
    This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
    predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).

    Args:
        class_cost:
            The relative weight of the classification error in the matching cost.
        bbox_cost:
            The relative weight of the L1 error of the bounding box coordinates in the matching cost.
        giou_cost:
            The relative weight of the giou loss of the bounding box in the matching cost.
    """

    def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
        super().__init__()

        self.class_cost = class_cost
        self.bbox_cost = bbox_cost
        self.giou_cost = giou_cost
        # With every weight at zero the cost matrix would be identically
        # zero and any assignment would be "optimal" — reject up front.
        if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
            raise ValueError("All costs of the Matcher can't be 0")

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Args:
            outputs (`dict`):
                A dictionary that contains at least these entries:
                * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
            targets (`List[dict]`):
                A list of targets (len(targets) = batch_size), where each target is a dict containing:
                * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
                  ground-truth
                  objects in the target) containing the class labels
                * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.

        Returns:
            `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
            - index_i is the indices of the selected predictions (in order)
            - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        batch_size, num_queries = outputs["logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        # (sigmoid, not softmax: this matcher uses focal-loss-style
        # per-class probabilities, as in Deformable DETR).
        out_prob = outputs["logits"].flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost.
        # Focal-loss formulation: the cost of matching a query to a target
        # class is the positive-focal term minus the negative-focal term,
        # evaluated at that target's class column.
        alpha = 0.25
        gamma = 2.0
        neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
        pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
        class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]

        # Compute the L1 cost between boxes
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

        # Compute the giou cost between boxes
        # NOTE(review): assumes predicted/target boxes are stored in
        # (center_x, center_y, w, h) format — converted to corners before
        # the GIoU computation; confirm against the model's box head.
        giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))

        # Final cost matrix
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        # scipy's solver runs on CPU numpy arrays, hence the .cpu() here.
        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()

        # Split the concatenated target dimension back per batch element and
        # solve one Hungarian assignment per image on its own slice.
        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]