Update modelling_magiv2.py
Browse files- modelling_magiv2.py +149 -2
modelling_magiv2.py
CHANGED
@@ -2,7 +2,6 @@ from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel
|
|
2 |
from transformers.models.conditional_detr.modeling_conditional_detr import (
|
3 |
ConditionalDetrMLPPredictionHead,
|
4 |
ConditionalDetrModelOutput,
|
5 |
-
ConditionalDetrHungarianMatcher,
|
6 |
inverse_sigmoid,
|
7 |
)
|
8 |
from .configuration_magiv2 import Magiv2Config
|
@@ -17,6 +16,7 @@ from .utils import UnionFind, sort_panels, sort_text_boxes_in_reading_order
|
|
17 |
import pulp
|
18 |
import scipy
|
19 |
import numpy as np
|
|
|
20 |
|
21 |
class Magiv2Model(PreTrainedModel):
|
22 |
config_class = Magiv2Config
|
@@ -611,4 +611,151 @@ class Magiv2Model(PreTrainedModel):
|
|
611 |
if apply_sigmoid:
|
612 |
text_tail_affinities = text_tail_affinities.sigmoid()
|
613 |
affinity_matrices.append(text_tail_affinities)
|
614 |
-
return affinity_matrices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from transformers.models.conditional_detr.modeling_conditional_detr import (
|
3 |
ConditionalDetrMLPPredictionHead,
|
4 |
ConditionalDetrModelOutput,
|
|
|
5 |
inverse_sigmoid,
|
6 |
)
|
7 |
from .configuration_magiv2 import Magiv2Config
|
|
|
16 |
import pulp
|
17 |
import scipy
|
18 |
import numpy as np
|
19 |
+
from scipy.optimize import linear_sum_assignment
|
20 |
|
21 |
class Magiv2Model(PreTrainedModel):
|
22 |
config_class = Magiv2Config
|
|
|
611 |
if apply_sigmoid:
|
612 |
text_tail_affinities = text_tail_affinities.sigmoid()
|
613 |
affinity_matrices.append(text_tail_affinities)
|
614 |
+
return affinity_matrices
|
615 |
+
|
616 |
+
# Copied from transformers.models.detr.modeling_detr._upcast
|
617 |
+
def _upcast(t):
|
618 |
+
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
619 |
+
if t.is_floating_point():
|
620 |
+
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
621 |
+
else:
|
622 |
+
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
623 |
+
|
624 |
+
|
625 |
+
# Copied from transformers.models.detr.modeling_detr.box_area
|
626 |
+
def box_area(boxes):
|
627 |
+
"""
|
628 |
+
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
629 |
+
|
630 |
+
Args:
|
631 |
+
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
|
632 |
+
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
|
633 |
+
< x2` and `0 <= y1 < y2`.
|
634 |
+
|
635 |
+
Returns:
|
636 |
+
`torch.FloatTensor`: a tensor containing the area for each box.
|
637 |
+
"""
|
638 |
+
boxes = _upcast(boxes)
|
639 |
+
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
640 |
+
|
641 |
+
|
642 |
+
# Copied from transformers.models.detr.modeling_detr.box_iou
|
643 |
+
def box_iou(boxes1, boxes2):
|
644 |
+
area1 = box_area(boxes1)
|
645 |
+
area2 = box_area(boxes2)
|
646 |
+
|
647 |
+
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
648 |
+
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
649 |
+
|
650 |
+
width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
|
651 |
+
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
|
652 |
+
|
653 |
+
union = area1[:, None] + area2 - inter
|
654 |
+
|
655 |
+
iou = inter / union
|
656 |
+
return iou, union
|
657 |
+
|
658 |
+
|
659 |
+
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
660 |
+
def generalized_box_iou(boxes1, boxes2):
|
661 |
+
"""
|
662 |
+
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
663 |
+
|
664 |
+
Returns:
|
665 |
+
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
|
666 |
+
"""
|
667 |
+
# degenerate boxes gives inf / nan results
|
668 |
+
# so do an early check
|
669 |
+
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
|
670 |
+
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
|
671 |
+
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
|
672 |
+
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
|
673 |
+
iou, union = box_iou(boxes1, boxes2)
|
674 |
+
|
675 |
+
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
676 |
+
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
677 |
+
|
678 |
+
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
|
679 |
+
area = width_height[:, :, 0] * width_height[:, :, 1]
|
680 |
+
|
681 |
+
return iou - (area - union) / area
|
682 |
+
|
683 |
+
|
684 |
+
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->ConditionalDetr
|
685 |
+
class ConditionalDetrHungarianMatcher(nn.Module):
|
686 |
+
"""
|
687 |
+
This class computes an assignment between the targets and the predictions of the network.
|
688 |
+
|
689 |
+
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
|
690 |
+
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
|
691 |
+
un-matched (and thus treated as non-objects).
|
692 |
+
|
693 |
+
Args:
|
694 |
+
class_cost:
|
695 |
+
The relative weight of the classification error in the matching cost.
|
696 |
+
bbox_cost:
|
697 |
+
The relative weight of the L1 error of the bounding box coordinates in the matching cost.
|
698 |
+
giou_cost:
|
699 |
+
The relative weight of the giou loss of the bounding box in the matching cost.
|
700 |
+
"""
|
701 |
+
|
702 |
+
def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
|
703 |
+
super().__init__()
|
704 |
+
|
705 |
+
self.class_cost = class_cost
|
706 |
+
self.bbox_cost = bbox_cost
|
707 |
+
self.giou_cost = giou_cost
|
708 |
+
if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
|
709 |
+
raise ValueError("All costs of the Matcher can't be 0")
|
710 |
+
|
711 |
+
@torch.no_grad()
|
712 |
+
def forward(self, outputs, targets):
|
713 |
+
"""
|
714 |
+
Args:
|
715 |
+
outputs (`dict`):
|
716 |
+
A dictionary that contains at least these entries:
|
717 |
+
* "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
718 |
+
* "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
|
719 |
+
targets (`List[dict]`):
|
720 |
+
A list of targets (len(targets) = batch_size), where each target is a dict containing:
|
721 |
+
* "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
|
722 |
+
ground-truth
|
723 |
+
objects in the target) containing the class labels
|
724 |
+
* "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
|
725 |
+
|
726 |
+
Returns:
|
727 |
+
`List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
|
728 |
+
- index_i is the indices of the selected predictions (in order)
|
729 |
+
- index_j is the indices of the corresponding selected targets (in order)
|
730 |
+
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
731 |
+
"""
|
732 |
+
batch_size, num_queries = outputs["logits"].shape[:2]
|
733 |
+
|
734 |
+
# We flatten to compute the cost matrices in a batch
|
735 |
+
out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
|
736 |
+
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
|
737 |
+
|
738 |
+
# Also concat the target labels and boxes
|
739 |
+
target_ids = torch.cat([v["class_labels"] for v in targets])
|
740 |
+
target_bbox = torch.cat([v["boxes"] for v in targets])
|
741 |
+
|
742 |
+
# Compute the classification cost.
|
743 |
+
alpha = 0.25
|
744 |
+
gamma = 2.0
|
745 |
+
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
|
746 |
+
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
|
747 |
+
class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
|
748 |
+
|
749 |
+
# Compute the L1 cost between boxes
|
750 |
+
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
|
751 |
+
|
752 |
+
# Compute the giou cost between boxes
|
753 |
+
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
|
754 |
+
|
755 |
+
# Final cost matrix
|
756 |
+
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
|
757 |
+
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
|
758 |
+
|
759 |
+
sizes = [len(v["boxes"]) for v in targets]
|
760 |
+
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
|
761 |
+
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|