z-uo commited on
Commit
d863531
1 Parent(s): 8485499
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageDraw
2
+
3
+ import torch
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+
7
+ import gradio as gr
8
+
9
+ # import sys
10
+ # sys.path.insert(0, './')
11
+ from test import create_letr, draw_fig
12
+ from models.preprocessing import *
13
+ from models.misc import nested_tensor_from_tensor_list
14
+
15
+
16
+ model = create_letr()
17
+
18
+ # PREPARE PREPROCESSING
19
+ test_size = 256
20
+ # transform_test = transforms.Compose([
21
+ # transforms.Resize((test_size)),
22
+ # transforms.ToTensor(),
23
+ # transforms.Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
24
+ # ])
25
+ normalize = Compose([
26
+ ToTensor(),
27
+ Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
28
+ Resize([test_size]),
29
+ ])
30
+
31
+
32
+ def predict(inp):
33
+ image = Image.fromarray(inp.astype('uint8'), 'RGB')
34
+ h, w = image.height, image.width
35
+ orig_size = torch.as_tensor([int(h), int(w)])
36
+
37
+ img = normalize(image)
38
+ inputs = nested_tensor_from_tensor_list([img])
39
+
40
+ with torch.no_grad():
41
+ outputs = model(inputs)[0]
42
+
43
+ draw_fig(image, outputs, orig_size)
44
+
45
+ return image
46
+
47
+
48
+ inputs = gr.inputs.Image()
49
+ outputs = gr.outputs.Image()
50
+ gr.Interface(
51
+ fn=predict,
52
+ inputs=inputs,
53
+ outputs=outputs,
54
+ examples=["demo.png", "tappeto-per-calibrazione.jpg"],
55
+ title="LETR",
56
+ description="Model for line detection..."
57
+ ).launch()
checkpoint0024.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26725e48335937731ac968a3fbde602d296ca3edcf93b79f4f76f356ad3a4ff9
3
+ size 380893769
demo.png ADDED
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (143 Bytes). View file
 
models/__pycache__/backbone.cpython-38.pyc ADDED
Binary file (4.75 kB). View file
 
models/__pycache__/letr.cpython-38.pyc ADDED
Binary file (13.2 kB). View file
 
models/__pycache__/letr_stack.cpython-38.pyc ADDED
Binary file (12.2 kB). View file
 
models/__pycache__/matcher.cpython-38.pyc ADDED
Binary file (4.12 kB). View file
 
models/__pycache__/misc.cpython-38.pyc ADDED
Binary file (14.6 kB). View file
 
models/__pycache__/multi_head_attention.cpython-38.pyc ADDED
Binary file (19.7 kB). View file
 
models/__pycache__/position_encoding.cpython-38.pyc ADDED
Binary file (3.65 kB). View file
 
models/__pycache__/preprocessing.cpython-38.pyc ADDED
Binary file (2.98 kB). View file
 
models/__pycache__/transformer.cpython-38.pyc ADDED
Binary file (9 kB). View file
 
models/backbone.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LETR Backbone modules.
3
+ modified based on https://github.com/facebookresearch/detr/blob/master/models/backbone.py
4
+ """
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision
10
+ from torch import nn
11
+ from torchvision.models._utils import IntermediateLayerGetter
12
+ from typing import Dict, List
13
+
14
+ from .misc import NestedTensor, is_main_process
15
+
16
+ from .position_encoding import build_position_encoding
17
+
18
+
19
+ class FrozenBatchNorm2d(torch.nn.Module):
20
+ """
21
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
22
+
23
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
24
+ without which any other models than torchvision.models.resnet[18,34,50,101]
25
+ produce nans.
26
+ """
27
+
28
+ def __init__(self, n):
29
+ super(FrozenBatchNorm2d, self).__init__()
30
+ self.register_buffer("weight", torch.ones(n))
31
+ self.register_buffer("bias", torch.zeros(n))
32
+ self.register_buffer("running_mean", torch.zeros(n))
33
+ self.register_buffer("running_var", torch.ones(n))
34
+
35
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
36
+ missing_keys, unexpected_keys, error_msgs):
37
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
38
+ if num_batches_tracked_key in state_dict:
39
+ del state_dict[num_batches_tracked_key]
40
+
41
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
42
+ state_dict, prefix, local_metadata, strict,
43
+ missing_keys, unexpected_keys, error_msgs)
44
+
45
+ def forward(self, x):
46
+ # move reshapes to the beginning
47
+ # to make it fuser-friendly
48
+ w = self.weight.reshape(1, -1, 1, 1)
49
+ b = self.bias.reshape(1, -1, 1, 1)
50
+ rv = self.running_var.reshape(1, -1, 1, 1)
51
+ rm = self.running_mean.reshape(1, -1, 1, 1)
52
+ eps = 1e-5
53
+ scale = w * (rv + eps).rsqrt()
54
+ bias = b - rm * scale
55
+ return x * scale + bias
56
+
57
+
58
+ class BackboneBase(nn.Module):
59
+
60
+ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
61
+ super().__init__()
62
+ for name, parameter in backbone.named_parameters():
63
+ if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
64
+ parameter.requires_grad_(False)
65
+ if return_interm_layers:
66
+ return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
67
+ else:
68
+ return_layers = {'layer4': "0"}
69
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
70
+ self.num_channels = num_channels
71
+
72
+ def forward(self, tensor_list: NestedTensor):
73
+ xs = self.body(tensor_list.tensors)
74
+ out: Dict[str, NestedTensor] = {}
75
+ for name, x in xs.items():
76
+
77
+ m = tensor_list.mask
78
+ assert m is not None
79
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
80
+ out[name] = NestedTensor(x, mask)
81
+ return out
82
+
83
+
84
+ class Backbone(BackboneBase):
85
+ """ResNet backbone with frozen BatchNorm."""
86
+ def __init__(self, name: str,
87
+ train_backbone: bool,
88
+ return_interm_layers: bool,
89
+ dilation: bool):
90
+ backbone = getattr(torchvision.models, name)(
91
+ replace_stride_with_dilation=[False, False, dilation],
92
+ pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
93
+ num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
94
+ super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
95
+
96
+
97
+ class Joiner(nn.Sequential):
98
+ def __init__(self, backbone, position_embedding):
99
+ super().__init__(backbone, position_embedding)
100
+
101
+ def forward(self, tensor_list: NestedTensor):
102
+ xs = self[0](tensor_list)
103
+ out: List[NestedTensor] = []
104
+ pos = []
105
+ for name, x in xs.items():
106
+ out.append(x)
107
+ # position encoding
108
+ pos.append(self[1](x).to(x.tensors.dtype))
109
+
110
+ return out, pos
111
+
112
+
113
+ def build_backbone(args):
114
+ position_embedding = build_position_encoding(args)
115
+ train_backbone = args.lr_backbone > 0
116
+ return_interm_layers = True
117
+ backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
118
+ model = Joiner(backbone, position_embedding)
119
+ model.num_channels = backbone.num_channels
120
+ return model
models/letr.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file provides coarse stage LETR definition
3
+ Modified based on https://github.com/facebookresearch/detr/blob/master/models/backbone.py
4
+ """
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ from .misc import (NestedTensor, nested_tensor_from_tensor_list,
10
+ accuracy, get_world_size, interpolate,
11
+ is_dist_avail_and_initialized)
12
+
13
+ from .backbone import build_backbone
14
+ from .matcher import build_matcher
15
+ from .transformer import build_transformer
16
+ from .letr_stack import LETRstack
17
+ import numpy as np
18
+
19
+ class LETR(nn.Module):
20
+ """ This is the LETR module that performs object detection """
21
+ def __init__(self, backbone, transformer, num_classes, num_queries, args, aux_loss=False):
22
+ super().__init__()
23
+ self.num_queries = num_queries
24
+ self.transformer = transformer
25
+ hidden_dim = transformer.d_model
26
+ self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
27
+
28
+ self.lines_embed = MLP(hidden_dim, hidden_dim, 4, 3)
29
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
30
+
31
+ channel = [256, 512, 1024, 2048]
32
+ self.input_proj = nn.Conv2d(channel[args.layer1_num], hidden_dim, kernel_size=1)
33
+
34
+ self.backbone = backbone
35
+ self.aux_loss = aux_loss
36
+ self.args = args
37
+
38
+ def forward(self, samples, postprocessors=None, targets=None, criterion=None):
39
+ if isinstance(samples, (list, torch.Tensor)):
40
+ samples = nested_tensor_from_tensor_list(samples)
41
+
42
+ features, pos = self.backbone(samples)
43
+
44
+ num = self.args.layer1_num
45
+ src, mask = features[num].decompose()
46
+ assert mask is not None
47
+ hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[num])[0]
48
+
49
+ outputs_class = self.class_embed(hs)
50
+ outputs_coord = self.lines_embed(hs).sigmoid()
51
+ out = {'pred_logits': outputs_class[-1], 'pred_lines': outputs_coord[-1]}
52
+ if self.aux_loss:
53
+ out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
54
+ return out
55
+
56
+ @torch.jit.unused
57
+ def _set_aux_loss(self, outputs_class, outputs_coord):
58
+ return [{'pred_logits': a, 'pred_lines': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
59
+
60
+ class SetCriterion(nn.Module):
61
+
62
+ def __init__(self, num_classes, weight_dict, eos_coef, losses, args, matcher=None):
63
+
64
+ super().__init__()
65
+ self.num_classes = num_classes
66
+
67
+ self.matcher = matcher
68
+
69
+ self.weight_dict = weight_dict
70
+ self.eos_coef = eos_coef
71
+ self.losses = losses
72
+ empty_weight = torch.ones(self.num_classes + 1)
73
+ empty_weight[-1] = self.eos_coef
74
+ self.register_buffer('empty_weight', empty_weight)
75
+ self.args = args
76
+ try:
77
+ self.args.label_loss_params = eval(self.args.label_loss_params) # Convert the string to dict.
78
+ except:
79
+ pass
80
+
81
+ def loss_lines_labels(self, outputs, targets, num_items, log=False, origin_indices=None):
82
+ """Classification loss (NLL)
83
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_lines]
84
+ """
85
+ assert 'pred_logits' in outputs
86
+ src_logits = outputs['pred_logits']
87
+
88
+ idx = self._get_src_permutation_idx(origin_indices)
89
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, origin_indices)])
90
+ target_classes = torch.full(src_logits.shape[:2], self.num_classes,
91
+ dtype=torch.int64, device=src_logits.device)
92
+ target_classes[idx] = target_classes_o
93
+
94
+ if self.args.label_loss_func == 'cross_entropy':
95
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
96
+ elif self.args.label_loss_func == 'focal_loss':
97
+ loss_ce = self.label_focal_loss(src_logits.transpose(1, 2), target_classes, self.empty_weight, **self.args.label_loss_params)
98
+ else:
99
+ raise ValueError()
100
+
101
+ losses = {'loss_ce': loss_ce}
102
+ return losses
103
+
104
+ def label_focal_loss(self, input, target, weight, gamma=2.0):
105
+ """ Focal loss for label prediction. """
106
+ # In our case, target has 2 classes: 0 for foreground (i.e. line) and 1 for background.
107
+ # The weight here can serve as the alpha hyperparameter in focal loss. However, in focal loss,
108
+ #
109
+ # Ref: https://github.com/facebookresearch/DETR/blob/699bf53f3e3ecd4f000007b8473eda6a08a8bed6/models/segmentation.py#L190
110
+ # Ref: https://medium.com/visionwizard/understanding-focal-loss-a-quick-read-b914422913e7
111
+
112
+ # input shape: [batch size, #classes, #queries]
113
+ # target shape: [batch size, #queries]
114
+ # weight shape: [#classes]
115
+
116
+ prob = F.softmax(input, 1) # Shape: [batch size, #classes, #queries].
117
+ ce_loss = F.cross_entropy(input, target, weight, reduction='none') # Shape: [batch size, #queries].
118
+ p_t = prob[:,1,:] * target + prob[:,0,:] * (1 - target) # Shape: [batch size, #queries]. Note: prob[:,0,:] + prob[:,1,:] should be 1.
119
+ loss = ce_loss * ((1 - p_t) ** gamma)
120
+ loss = loss.mean() # Original label loss (i.e. cross entropy) does not consider the #lines, so we also do not consider that.
121
+ return loss
122
+
123
+ @torch.no_grad()
124
+ def loss_cardinality(self, outputs, targets, num_items, origin_indices=None):
125
+ """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty lines
126
+ This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
127
+ """
128
+ pred_logits = outputs['pred_logits']
129
+ device = pred_logits.device
130
+ tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
131
+ # Count the number of predictions that are NOT "no-object" (which is the last class)
132
+ card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
133
+ card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
134
+ losses = {'cardinality_error': card_err}
135
+ return losses
136
+
137
+ def loss_lines_POST(self, outputs, targets, num_items, origin_indices=None):
138
+ assert 'POST_pred_lines' in outputs
139
+
140
+ if outputs['POST_pred_lines'].shape[1] == 1000:
141
+ idx = self._get_src_permutation_idx(origin_indices)
142
+
143
+ src_lines = outputs['POST_pred_lines'][idx]
144
+
145
+ else:
146
+ src_lines = outputs['POST_pred_lines'].squeeze(0)
147
+
148
+
149
+ target_lines = torch.cat([t['lines'][i] for t, (_, i) in zip(targets, origin_indices)], dim=0)
150
+
151
+ loss_line = F.l1_loss(src_lines, target_lines, reduction='none')
152
+
153
+ losses = {}
154
+ losses['loss_line'] = loss_line.sum() / num_items
155
+
156
+ return losses
157
+
158
+ def loss_lines(self, outputs, targets, num_items, origin_indices=None):
159
+ assert 'pred_lines' in outputs
160
+
161
+ idx = self._get_src_permutation_idx(origin_indices)
162
+
163
+ src_lines = outputs['pred_lines'][idx]
164
+ target_lines = torch.cat([t['lines'][i] for t, (_, i) in zip(targets, origin_indices)], dim=0)
165
+
166
+ loss_line = F.l1_loss(src_lines, target_lines, reduction='none')
167
+
168
+ losses = {}
169
+ losses['loss_line'] = loss_line.sum() / num_items
170
+
171
+ return losses
172
+
173
+ def _get_src_permutation_idx(self, indices):
174
+ # permute predictions following indices
175
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
176
+ src_idx = torch.cat([src for (src, _) in indices])
177
+ return batch_idx, src_idx
178
+
179
+ def _get_tgt_permutation_idx(self, indices):
180
+ # permute targets following indices
181
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
182
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
183
+ return batch_idx, tgt_idx
184
+
185
+ def get_loss(self, loss, outputs, targets, num_items, **kwargs):
186
+
187
+ loss_map = {
188
+ 'POST_lines_labels': self.loss_lines_labels,
189
+ 'POST_lines': self.loss_lines,
190
+ 'lines_labels': self.loss_lines_labels,
191
+ 'cardinality': self.loss_cardinality,
192
+ 'lines': self.loss_lines,
193
+ }
194
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
195
+ return loss_map[loss](outputs, targets, num_items, **kwargs)
196
+
197
+ def forward(self, outputs, targets, origin_indices=None):
198
+ """ This performs the loss computation.
199
+ Parameters:
200
+ outputs: dict of tensors, see the output specification of the model for the format
201
+ targets: list of dicts, such that len(targets) == batch_size.
202
+ The expected keys in each dict depends on the losses applied, see each loss' doc
203
+ """
204
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
205
+
206
+
207
+ origin_indices = self.matcher(outputs_without_aux, targets)
208
+
209
+
210
+ num_items = sum(len(t["labels"]) for t in targets)
211
+
212
+ num_items = torch.as_tensor([num_items], dtype=torch.float, device=next(iter(outputs.values())).device)
213
+ if is_dist_avail_and_initialized():
214
+ torch.distributed.all_reduce(num_items)
215
+ num_items = torch.clamp(num_items / get_world_size(), min=1).item()
216
+
217
+ # Compute all the requested losses
218
+ losses = {}
219
+ for loss in self.losses:
220
+ losses.update(self.get_loss(loss, outputs, targets, num_items, origin_indices=origin_indices))
221
+
222
+
223
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
224
+ aux_name = 'aux_outputs'
225
+ if aux_name in outputs:
226
+ for i, aux_outputs in enumerate(outputs[aux_name]):
227
+
228
+ origin_indices = self.matcher(aux_outputs, targets)
229
+
230
+ for loss in self.losses:
231
+
232
+ kwargs = {}
233
+ if loss == 'labels':
234
+ # Logging is enabled only for the last layer
235
+ kwargs = {'log': False}
236
+
237
+ l_dict = self.get_loss(loss, aux_outputs, targets, num_items, origin_indices=origin_indices, **kwargs)
238
+
239
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
240
+ losses.update(l_dict)
241
+
242
+ return losses
243
+
244
+
245
+ class PostProcess_Line(nn.Module):
246
+
247
+ """ This module converts the model's output into the format expected by the coco api"""
248
+ @torch.no_grad()
249
+ def forward(self, outputs, target_sizes, output_type):
250
+ """ Perform the computation
251
+ Parameters:
252
+ outputs: raw outputs of the model
253
+ target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
254
+ For evaluation, this must be the original image size (before any data augmentation)
255
+ For visualization, this should be the image size after data augment, but before padding
256
+ """
257
+ if output_type == "prediction":
258
+ out_logits, out_line = outputs['pred_logits'], outputs['pred_lines']
259
+
260
+ assert len(out_logits) == len(target_sizes)
261
+ assert target_sizes.shape[1] == 2
262
+
263
+ prob = F.softmax(out_logits, -1)
264
+ scores, labels = prob[..., :-1].max(-1)
265
+
266
+ # convert to [x0, y0, x1, y1] format
267
+ img_h, img_w = target_sizes.unbind(1)
268
+
269
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
270
+ lines = out_line * scale_fct[:, None, :]
271
+
272
+ results = [{'scores': s, 'labels': l, 'lines': b} for s, l, b in zip(scores, labels, lines)]
273
+ elif output_type == "prediction_POST":
274
+ out_logits, out_line = outputs['pred_logits'], outputs['POST_pred_lines']
275
+
276
+ assert len(out_logits) == len(target_sizes)
277
+ assert target_sizes.shape[1] == 2
278
+
279
+ prob = F.softmax(out_logits, -1)
280
+ scores, labels = prob[..., :-1].max(-1)
281
+
282
+ # convert to [x0, y0, x1, y1] format
283
+ img_h, img_w = target_sizes.unbind(1)
284
+
285
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
286
+ lines = out_line * scale_fct[:, None, :]
287
+
288
+ results = [{'scores': s, 'labels': l, 'lines': b} for s, l, b in zip(scores, labels, lines)]
289
+ elif output_type == "ground_truth":
290
+ results = []
291
+ for dic in outputs:
292
+ lines = dic['lines']
293
+ img_h, img_w = target_sizes.unbind(1)
294
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
295
+ scaled_lines = lines * scale_fct
296
+ results.append({'labels': dic['labels'], 'lines': scaled_lines, 'image_id': dic['image_id']})
297
+ else:
298
+ assert False
299
+ return results
300
+
301
+
302
+ class MLP(nn.Module):
303
+ """ Very simple multi-layer perceptron (also called FFN)"""
304
+
305
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
306
+ super().__init__()
307
+ self.num_layers = num_layers
308
+ h = [hidden_dim] * (num_layers - 1)
309
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
310
+
311
+ def forward(self, x):
312
+ for i, layer in enumerate(self.layers):
313
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
314
+ return x
315
+
316
+
317
+ def build(args):
318
+ num_classes = 1
319
+
320
+ device = torch.device(args.device)
321
+
322
+ backbone = build_backbone(args)
323
+
324
+ transformer = build_transformer(args)
325
+
326
+ model = LETR(
327
+ backbone,
328
+ transformer,
329
+ num_classes=num_classes,
330
+ num_queries=args.num_queries,
331
+ args=args,
332
+ aux_loss=args.aux_loss,
333
+ )
334
+
335
+ if args.LETRpost:
336
+ model = LETRstack(model, args=args)
337
+
338
+
339
+ matcher = build_matcher(args, type='origin_line')
340
+
341
+ losses = []
342
+ weight_dict = {}
343
+
344
+ if args.LETRpost:
345
+ losses.append('POST_lines_labels')
346
+ losses.append('POST_lines')
347
+ weight_dict['loss_ce'] = 1
348
+ weight_dict['loss_line'] = args.line_loss_coef
349
+ aux_layer = args.second_dec_layers
350
+ else:
351
+ losses.append('lines_labels')
352
+ losses.append('lines')
353
+ weight_dict['loss_ce'] = 1
354
+ weight_dict['loss_line'] = args.line_loss_coef
355
+ aux_layer = args.dec_layers
356
+
357
+ if args.aux_loss:
358
+ aux_weight_dict = {}
359
+ for i in range(aux_layer - 1):
360
+ aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
361
+ weight_dict.update(aux_weight_dict)
362
+
363
+
364
+ criterion = SetCriterion(num_classes, weight_dict=weight_dict, eos_coef=args.eos_coef, losses=losses, args=args, matcher=matcher)
365
+ criterion.to(device)
366
+
367
+
368
+ postprocessors = {'line': PostProcess_Line()}
369
+
370
+
371
+ return model, criterion, postprocessors
models/letr_stack.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file provides fine stage LETR definition
3
+
4
+ """
5
+ import io
6
+ from collections import defaultdict
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch import Tensor
13
+ from PIL import Image
14
+ from .misc import NestedTensor, nested_tensor_from_tensor_list
15
+ import copy
16
+
17
+
18
+ class LETRstack(nn.Module):
19
+ def __init__(self, letr, args):
20
+ super().__init__()
21
+ self.letr = letr
22
+ self.backbone = self.letr.backbone
23
+
24
+ if args.layer1_frozen:
25
+ # freeze backbone, encoder, decoder
26
+ for n, p in self.named_parameters():
27
+ p.requires_grad_(False)
28
+
29
+ hidden_dim, nheads = letr.transformer.d_model, letr.transformer.nhead
30
+
31
+ # add new input proj layer
32
+ channel = [256, 512, 1024, 2048]
33
+ self.input_proj = nn.Conv2d(channel[args.layer2_num], hidden_dim, kernel_size=1)
34
+
35
+ # add new transformer encoder decoder
36
+ self.transformer = Transformer( d_model=args.second_hidden_dim, dropout=args.second_dropout, nhead=args.second_nheads,
37
+ dim_feedforward=args.second_dim_feedforward, num_encoder_layers=args.second_enc_layers,
38
+ num_decoder_layers=args.second_dec_layers, normalize_before=args.second_pre_norm, return_intermediate_dec=True,)
39
+
40
+ # output layer
41
+ self.class_embed = nn.Linear(hidden_dim, 1 + 1)
42
+ self.lines_embed = MLP(hidden_dim, hidden_dim, 4, 3)
43
+
44
+
45
+ self.aux_loss=args.aux_loss
46
+ self.args = args
47
+
48
+ def forward(self, samples, postprocessors=None, targets=None, criterion=None):
49
+ if isinstance(samples, (list, torch.Tensor)):
50
+ samples = nested_tensor_from_tensor_list(samples)
51
+
52
+ # backbone
53
+ features, pos = self.letr.backbone(samples)
54
+
55
+ # layer 1
56
+ l1_num = self.args.layer1_num
57
+ src1, mask1 = features[l1_num].decompose()
58
+ assert mask1 is not None
59
+
60
+ # layer 1 transformer
61
+ hs1, _ = self.letr.transformer(self.letr.input_proj(src1), mask1, self.letr.query_embed.weight, pos[l1_num])
62
+
63
+ # layer 2
64
+ l2_num = self.args.layer2_num
65
+ src2, mask2 = features[l2_num].decompose()
66
+ src2 = self.input_proj(src2)
67
+
68
+ # layer 2 transformer
69
+ hs2, memory, _ = self.transformer(src2, mask2, hs1[-1], pos[l2_num])
70
+
71
+ outputs_class = self.class_embed(hs2)
72
+ outputs_coord = self.lines_embed(hs2).sigmoid()
73
+ out = {}
74
+ out["pred_logits"] = outputs_class[-1]
75
+ out["pred_lines"] = outputs_coord[-1]
76
+
77
+ if self.aux_loss:
78
+ out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
79
+
80
+ return out, None
81
+
82
+ @torch.jit.unused
83
+ def _set_aux_loss(self, outputs_class, outputs_coord):
84
+ # this is a workaround to make torchscript happy, as torchscript
85
+ # doesn't support dictionary with non-homogeneous values, such
86
+ # as a dict having both a Tensor and a list.
87
+ return [{'pred_logits': a, 'pred_lines': b}
88
+ for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
89
+
90
+ @torch.jit.unused
91
+ def _set_aux_loss_POST(self, outputs_class, outputs_coord):
92
+ # this is a workaround to make torchscript happy, as torchscript
93
+ # doesn't support dictionary with non-homogeneous values, such
94
+ # as a dict having both a Tensor and a list.
95
+ return [{'POST_pred_lines': b} for b in outputs_coord[:-1]]
96
+
97
+ def _expand(tensor, length: int):
98
+ return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
99
+
100
+ class MLP(nn.Module):
101
+ """ Very simple multi-layer perceptron (also called FFN)"""
102
+
103
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
104
+ super().__init__()
105
+ self.num_layers = num_layers
106
+ h = [hidden_dim] * (num_layers - 1)
107
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
108
+
109
+ def forward(self, x):
110
+ for i, layer in enumerate(self.layers):
111
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
112
+ return x
113
+
114
+
115
+ class Transformer(nn.Module):
116
+
117
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
118
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
119
+ activation="relu", normalize_before=False,
120
+ return_intermediate_dec=False):
121
+ super().__init__()
122
+
123
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
124
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
125
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
126
+
127
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
128
+ decoder_norm = nn.LayerNorm(d_model)
129
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
130
+ return_intermediate=return_intermediate_dec)
131
+
132
+ self._reset_parameters()
133
+
134
+ self.d_model = d_model
135
+ self.nhead = nhead
136
+
137
+ def _reset_parameters(self):
138
+ for p in self.parameters():
139
+ if p.dim() > 1:
140
+ nn.init.xavier_uniform_(p)
141
+
142
+ def forward(self, src, mask, query_embed, pos_embed):
143
+ # flatten NxCxHxW to HWxNxC
144
+ bs, c, h, w = src.shape
145
+ src = src.flatten(2).permute(2, 0, 1)
146
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
147
+ mask = mask.flatten(1)
148
+
149
+ query_embed = query_embed.permute(1, 0, 2)
150
+ tgt = torch.zeros_like(query_embed)
151
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
152
+ hs, attn_output_weights = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
153
+ return hs.transpose(1, 2), memory, attn_output_weights
154
+
155
+ class TransformerEncoder(nn.Module):
156
+
157
+ def __init__(self, encoder_layer, num_layers, norm=None):
158
+ super().__init__()
159
+ self.layers = _get_clones(encoder_layer, num_layers)
160
+ self.num_layers = num_layers
161
+ self.norm = norm
162
+
163
+ def forward(self, src,
164
+ mask: Optional[Tensor] = None,
165
+ src_key_padding_mask: Optional[Tensor] = None,
166
+ pos: Optional[Tensor] = None):
167
+ output = src
168
+
169
+ for layer in self.layers:
170
+ output = layer(output, src_mask=mask,
171
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
172
+
173
+ if self.norm is not None:
174
+ output = self.norm(output)
175
+
176
+ return output
177
+
178
+ class TransformerDecoder(nn.Module):
179
+
180
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
181
+ super().__init__()
182
+ self.layers = _get_clones(decoder_layer, num_layers)
183
+ self.num_layers = num_layers
184
+ self.norm = norm
185
+ self.return_intermediate = return_intermediate
186
+
187
+ def forward(self, tgt, memory,
188
+ tgt_mask: Optional[Tensor] = None,
189
+ memory_mask: Optional[Tensor] = None,
190
+ tgt_key_padding_mask: Optional[Tensor] = None,
191
+ memory_key_padding_mask: Optional[Tensor] = None,
192
+ pos: Optional[Tensor] = None,
193
+ query_pos: Optional[Tensor] = None):
194
+ output = tgt
195
+
196
+ intermediate = []
197
+ attn_output_weights_list = []
198
+ for layer in self.layers:
199
+ output, attn_output_weights = layer(output, memory, tgt_mask=tgt_mask,
200
+ memory_mask=memory_mask,
201
+ tgt_key_padding_mask=tgt_key_padding_mask,
202
+ memory_key_padding_mask=memory_key_padding_mask,
203
+ pos=pos, query_pos=query_pos)
204
+ if self.return_intermediate:
205
+ intermediate.append(self.norm(output))
206
+ attn_output_weights_list.append(attn_output_weights)
207
+ if self.norm is not None:
208
+ output = self.norm(output)
209
+ if self.return_intermediate:
210
+ intermediate.pop()
211
+ intermediate.append(output)
212
+
213
+ if self.return_intermediate:
214
+ return torch.stack(intermediate), attn_output_weights_list
215
+
216
+ return output.unsqueeze(0), attn_output_weights
217
+
218
+
219
+ class TransformerEncoderLayer(nn.Module):
220
+
221
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
222
+ activation="relu", normalize_before=False):
223
+ super().__init__()
224
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
225
+ # Implementation of Feedforward model
226
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
227
+ self.dropout = nn.Dropout(dropout)
228
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
229
+
230
+ self.norm1 = nn.LayerNorm(d_model)
231
+ self.norm2 = nn.LayerNorm(d_model)
232
+ self.dropout1 = nn.Dropout(dropout)
233
+ self.dropout2 = nn.Dropout(dropout)
234
+
235
+ self.activation = _get_activation_fn(activation)
236
+ self.normalize_before = normalize_before
237
+
238
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
239
+ return tensor if pos is None else tensor + pos
240
+
241
+ def forward_post(self,
242
+ src,
243
+ src_mask: Optional[Tensor] = None,
244
+ src_key_padding_mask: Optional[Tensor] = None,
245
+ pos: Optional[Tensor] = None):
246
+ q = k = self.with_pos_embed(src, pos)
247
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
248
+ key_padding_mask=src_key_padding_mask)[0]
249
+ src = src + self.dropout1(src2)
250
+ src = self.norm1(src)
251
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
252
+ src = src + self.dropout2(src2)
253
+ src = self.norm2(src)
254
+ return src
255
+
256
+ def forward_pre(self, src,
257
+ src_mask: Optional[Tensor] = None,
258
+ src_key_padding_mask: Optional[Tensor] = None,
259
+ pos: Optional[Tensor] = None):
260
+ src2 = self.norm1(src)
261
+ q = k = self.with_pos_embed(src2, pos)
262
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
263
+ key_padding_mask=src_key_padding_mask)[0]
264
+ src = src + self.dropout1(src2)
265
+ src2 = self.norm2(src)
266
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
267
+ src = src + self.dropout2(src2)
268
+ return src
269
+
270
+ def forward(self, src,
271
+ src_mask: Optional[Tensor] = None,
272
+ src_key_padding_mask: Optional[Tensor] = None,
273
+ pos: Optional[Tensor] = None):
274
+ if self.normalize_before:
275
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
276
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
277
+
278
+
279
+ class TransformerDecoderLayer(nn.Module):
280
+
281
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
282
+ activation="relu", normalize_before=False):
283
+ super().__init__()
284
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
285
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
286
+ # Implementation of Feedforward model
287
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
288
+ self.dropout = nn.Dropout(dropout)
289
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
290
+
291
+ self.norm1 = nn.LayerNorm(d_model)
292
+ self.norm2 = nn.LayerNorm(d_model)
293
+ self.norm3 = nn.LayerNorm(d_model)
294
+ self.dropout1 = nn.Dropout(dropout)
295
+ self.dropout2 = nn.Dropout(dropout)
296
+ self.dropout3 = nn.Dropout(dropout)
297
+
298
+ self.activation = _get_activation_fn(activation)
299
+ self.normalize_before = normalize_before
300
+
301
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
302
+ return tensor if pos is None else tensor + pos
303
+
304
+ def forward_post(self, tgt, memory,
305
+ tgt_mask: Optional[Tensor] = None,
306
+ memory_mask: Optional[Tensor] = None,
307
+ tgt_key_padding_mask: Optional[Tensor] = None,
308
+ memory_key_padding_mask: Optional[Tensor] = None,
309
+ pos: Optional[Tensor] = None,
310
+ query_pos: Optional[Tensor] = None):
311
+ q = k = self.with_pos_embed(tgt, query_pos)
312
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
313
+ key_padding_mask=tgt_key_padding_mask)[0]
314
+ tgt = tgt + self.dropout1(tgt2)
315
+ tgt = self.norm1(tgt)
316
+ tgt2, attn_output_weights = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
317
+ key=self.with_pos_embed(memory, pos),
318
+ value=memory, attn_mask=memory_mask,
319
+ key_padding_mask=memory_key_padding_mask)
320
+ tgt = tgt + self.dropout2(tgt2)
321
+ tgt = self.norm2(tgt)
322
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
323
+ tgt = tgt + self.dropout3(tgt2)
324
+ tgt = self.norm3(tgt)
325
+ return tgt, attn_output_weights
326
+
327
+ def forward_pre(self, tgt, memory,
328
+ tgt_mask: Optional[Tensor] = None,
329
+ memory_mask: Optional[Tensor] = None,
330
+ tgt_key_padding_mask: Optional[Tensor] = None,
331
+ memory_key_padding_mask: Optional[Tensor] = None,
332
+ pos: Optional[Tensor] = None,
333
+ query_pos: Optional[Tensor] = None):
334
+ tgt2 = self.norm1(tgt)
335
+ q = k = self.with_pos_embed(tgt2, query_pos)
336
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
337
+ key_padding_mask=tgt_key_padding_mask)[0]
338
+ tgt = tgt + self.dropout1(tgt2)
339
+ tgt2 = self.norm2(tgt)
340
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
341
+ key=self.with_pos_embed(memory, pos),
342
+ value=memory, attn_mask=memory_mask,
343
+ key_padding_mask=memory_key_padding_mask)[0]
344
+ tgt = tgt + self.dropout2(tgt2)
345
+ tgt2 = self.norm3(tgt)
346
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
347
+ tgt = tgt + self.dropout3(tgt2)
348
+ return tgt
349
+
350
+ def forward(self, tgt, memory,
351
+ tgt_mask: Optional[Tensor] = None,
352
+ memory_mask: Optional[Tensor] = None,
353
+ tgt_key_padding_mask: Optional[Tensor] = None,
354
+ memory_key_padding_mask: Optional[Tensor] = None,
355
+ pos: Optional[Tensor] = None,
356
+ query_pos: Optional[Tensor] = None):
357
+ if self.normalize_before:
358
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
359
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
360
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
361
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
362
+
363
+
364
+ def _get_clones(module, N):
365
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
366
+
367
+ def _get_activation_fn(activation):
368
+ """Return an activation function given a string"""
369
+ if activation == "relu":
370
+ return F.relu
371
+ if activation == "gelu":
372
+ return F.gelu
373
+ if activation == "glu":
374
+ return F.glu
375
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
376
+
models/matcher.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modules to compute the matching cost and solve the corresponding LSAP.
3
+ """
4
+ import torch
5
+ from scipy.optimize import linear_sum_assignment
6
+ from torch import nn
7
+
8
+ class HungarianMatcher_Line(nn.Module):
9
+ """This class computes an assignment between the targets and the predictions of the network
10
+
11
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
12
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
13
+ while the others are un-matched (and thus treated as non-objects).
14
+ """
15
+
16
+ def __init__(self, cost_class: float = 1, cost_line: float = 1):
17
+ """Creates the matcher
18
+
19
+ Params:
20
+ cost_class: This is the relative weight of the classification error in the matching cost
21
+ cost_line: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
22
+ """
23
+ super().__init__()
24
+ self.cost_class = cost_class
25
+ self.cost_line = cost_line
26
+ assert cost_class != 0 or cost_line != 0, "all costs cant be 0"
27
+
28
+ @torch.no_grad()
29
+ def forward(self, outputs, targets):
30
+ """ Performs the matching
31
+
32
+ Params:
33
+ outputs: This is a dict that contains at least these entries:
34
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
35
+ "pred_lines": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
36
+
37
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
38
+ "labels": Tensor of dim [num_target_lines] (where num_target_lines is the number of ground-truth
39
+ objects in the target) containing the class labels
40
+ "lines": Tensor of dim [num_target_lines, 4] containing the target box coordinates
41
+
42
+ Returns:
43
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
44
+ - index_i is the indices of the selected predictions (in order)
45
+ - index_j is the indices of the corresponding selected targets (in order)
46
+ For each batch element, it holds:
47
+ len(index_i) = len(index_j) = min(num_queries, num_target_lines)
48
+ """
49
+ bs, num_queries = outputs["pred_logits"].shape[:2]
50
+
51
+ # We flatten to compute the cost matrices in a batch
52
+ out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
53
+
54
+ out_line = outputs["pred_lines"].flatten(0, 1) # [batch_size * num_queries, 4]
55
+ tgt_line = torch.cat([v["lines"] for v in targets])
56
+
57
+
58
+ # Also concat the target labels and lines
59
+ tgt_ids = torch.cat([v["labels"] for v in targets])
60
+
61
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
62
+ # but approximate it in 1 - proba[target class].
63
+ # The 1 is a constant that doesn't change the matching, it can be ommitted.
64
+ cost_class = -out_prob[:, tgt_ids]
65
+
66
+ # Compute the L1 cost between lines
67
+ cost_line = torch.cdist(out_line, tgt_line, p=1)
68
+
69
+ # Final cost matrix
70
+ C = self.cost_line * cost_line + self.cost_class * cost_class
71
+ C = C.view(bs, num_queries, -1).cpu()
72
+
73
+ sizes = [len(v["lines"]) for v in targets]
74
+ indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
75
+
76
+ return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
77
+
78
+
79
+
80
+ def build_matcher(args, type=None):
81
+ return HungarianMatcher_Line(cost_class=args.set_cost_class, cost_line=args.set_cost_line)
models/misc.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Misc functions, including distributed helpers.
4
+
5
+ Mostly copy-paste from torchvision references.
6
+ """
7
+ import os
8
+ import subprocess
9
+ import time
10
+ from collections import defaultdict, deque
11
+ import datetime
12
+ import pickle
13
+ from typing import Optional, List
14
+
15
+ import torch
16
+ import torch.distributed as dist
17
+ from torch import Tensor
18
+
19
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
20
+ import torchvision
21
+ if float(torchvision.__version__[:3]) < 0.7:
22
+ from torchvision.ops import _new_empty_tensor
23
+ from torchvision.ops.misc import _output_size
24
+
25
+
26
+ class SmoothedValue(object):
27
+ """Track a series of values and provide access to smoothed values over a
28
+ window or the global series average.
29
+ """
30
+
31
+ def __init__(self, window_size=20, fmt=None):
32
+ if fmt is None:
33
+ fmt = "{median:.4f} ({global_avg:.4f})"
34
+ self.deque = deque(maxlen=window_size)
35
+ self.total = 0.0
36
+ self.count = 0
37
+ self.fmt = fmt
38
+
39
+ def update(self, value, n=1):
40
+ self.deque.append(value)
41
+ self.count += n
42
+ self.total += value * n
43
+
44
+ def synchronize_between_processes(self):
45
+ """
46
+ Warning: does not synchronize the deque!
47
+ """
48
+ if not is_dist_avail_and_initialized():
49
+ return
50
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
51
+ dist.barrier()
52
+ dist.all_reduce(t)
53
+ t = t.tolist()
54
+ self.count = int(t[0])
55
+ self.total = t[1]
56
+
57
+ @property
58
+ def median(self):
59
+ d = torch.tensor(list(self.deque))
60
+ return d.median().item()
61
+
62
+ @property
63
+ def avg(self):
64
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
65
+ return d.mean().item()
66
+
67
+ @property
68
+ def global_avg(self):
69
+ return self.total / self.count
70
+
71
+ @property
72
+ def max(self):
73
+ return max(self.deque)
74
+
75
+ @property
76
+ def value(self):
77
+ return self.deque[-1]
78
+
79
+ def __str__(self):
80
+ return self.fmt.format(
81
+ median=self.median,
82
+ avg=self.avg,
83
+ global_avg=self.global_avg,
84
+ max=self.max,
85
+ value=self.value)
86
+
87
+
88
+ def all_gather(data):
89
+ """
90
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
91
+ Args:
92
+ data: any picklable object
93
+ Returns:
94
+ list[data]: list of data gathered from each rank
95
+ """
96
+ world_size = get_world_size()
97
+ if world_size == 1:
98
+ return [data]
99
+
100
+ # serialized to a Tensor
101
+ buffer = pickle.dumps(data)
102
+ storage = torch.ByteStorage.from_buffer(buffer)
103
+ tensor = torch.ByteTensor(storage).to("cuda")
104
+
105
+ # obtain Tensor size of each rank
106
+ local_size = torch.tensor([tensor.numel()], device="cuda")
107
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
108
+ dist.all_gather(size_list, local_size)
109
+ size_list = [int(size.item()) for size in size_list]
110
+ max_size = max(size_list)
111
+
112
+ # receiving Tensor from all ranks
113
+ # we pad the tensor because torch all_gather does not support
114
+ # gathering tensors of different shapes
115
+ tensor_list = []
116
+ for _ in size_list:
117
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
118
+ if local_size != max_size:
119
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
120
+ tensor = torch.cat((tensor, padding), dim=0)
121
+ dist.all_gather(tensor_list, tensor)
122
+
123
+ data_list = []
124
+ for size, tensor in zip(size_list, tensor_list):
125
+ buffer = tensor.cpu().numpy().tobytes()[:size]
126
+ data_list.append(pickle.loads(buffer))
127
+
128
+ return data_list
129
+
130
+
131
+ def reduce_dict(input_dict, average=True):
132
+ """
133
+ Args:
134
+ input_dict (dict): all the values will be reduced
135
+ average (bool): whether to do average or sum
136
+ Reduce the values in the dictionary from all processes so that all processes
137
+ have the averaged results. Returns a dict with the same fields as
138
+ input_dict, after reduction.
139
+ """
140
+ world_size = get_world_size()
141
+ if world_size < 2:
142
+ return input_dict
143
+ with torch.no_grad():
144
+ names = []
145
+ values = []
146
+ # sort the keys so that they are consistent across processes
147
+ for k in sorted(input_dict.keys()):
148
+ names.append(k)
149
+ values.append(input_dict[k])
150
+ values = torch.stack(values, dim=0)
151
+ dist.all_reduce(values)
152
+ if average:
153
+ values /= world_size
154
+ reduced_dict = {k: v for k, v in zip(names, values)}
155
+ return reduced_dict
156
+
157
+
158
+ class MetricLogger(object):
159
+ def __init__(self, delimiter="\t"):
160
+ self.meters = defaultdict(SmoothedValue)
161
+ self.delimiter = delimiter
162
+
163
+ def update(self, **kwargs):
164
+ for k, v in kwargs.items():
165
+ if isinstance(v, torch.Tensor):
166
+ v = v.item()
167
+ assert isinstance(v, (float, int))
168
+ self.meters[k].update(v)
169
+
170
+ def __getattr__(self, attr):
171
+ if attr in self.meters:
172
+ return self.meters[attr]
173
+ if attr in self.__dict__:
174
+ return self.__dict__[attr]
175
+ raise AttributeError("'{}' object has no attribute '{}'".format(
176
+ type(self).__name__, attr))
177
+
178
+ def __str__(self):
179
+ loss_str = []
180
+ for name, meter in self.meters.items():
181
+ loss_str.append(
182
+ "{}: {}".format(name, str(meter))
183
+ )
184
+ return self.delimiter.join(loss_str)
185
+
186
+ def synchronize_between_processes(self):
187
+ for meter in self.meters.values():
188
+ meter.synchronize_between_processes()
189
+
190
+ def add_meter(self, name, meter):
191
+ self.meters[name] = meter
192
+
193
+ def log_every(self, iterable, print_freq, header=None):
194
+ i = 0
195
+ if not header:
196
+ header = ''
197
+ start_time = time.time()
198
+ end = time.time()
199
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
200
+ data_time = SmoothedValue(fmt='{avg:.4f}')
201
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
202
+ if torch.cuda.is_available():
203
+ log_msg = self.delimiter.join([
204
+ header,
205
+ '[{0' + space_fmt + '}/{1}]',
206
+ 'eta: {eta}',
207
+ '{meters}',
208
+ 'time: {time}',
209
+ 'data: {data}',
210
+ 'max mem: {memory:.0f}'
211
+ ])
212
+ else:
213
+ log_msg = self.delimiter.join([
214
+ header,
215
+ '[{0' + space_fmt + '}/{1}]',
216
+ 'eta: {eta}',
217
+ '{meters}',
218
+ 'time: {time}',
219
+ 'data: {data}'
220
+ ])
221
+ MB = 1024.0 * 1024.0
222
+ for obj in iterable:
223
+ data_time.update(time.time() - end)
224
+ yield obj
225
+ iter_time.update(time.time() - end)
226
+ if i % print_freq == 0 or i == len(iterable) - 1:
227
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
228
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
229
+ if torch.cuda.is_available():
230
+ print(log_msg.format(
231
+ i, len(iterable), eta=eta_string,
232
+ meters=str(self),
233
+ time=str(iter_time), data=str(data_time),
234
+ memory=torch.cuda.max_memory_allocated() / MB))
235
+ else:
236
+ print(log_msg.format(
237
+ i, len(iterable), eta=eta_string,
238
+ meters=str(self),
239
+ time=str(iter_time), data=str(data_time)))
240
+ i += 1
241
+ end = time.time()
242
+ total_time = time.time() - start_time
243
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
244
+ print('{} Total time: {} ({:.4f} s / it)'.format(
245
+ header, total_time_str, total_time / len(iterable)))
246
+
247
+
248
+ def get_sha():
249
+ cwd = os.path.dirname(os.path.abspath(__file__))
250
+
251
+ def _run(command):
252
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
253
+ sha = 'N/A'
254
+ diff = "clean"
255
+ branch = 'N/A'
256
+ try:
257
+ sha = _run(['git', 'rev-parse', 'HEAD'])
258
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
259
+ diff = _run(['git', 'diff-index', 'HEAD'])
260
+ diff = "has uncommited changes" if diff else "clean"
261
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
262
+ except Exception:
263
+ pass
264
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
265
+ return message
266
+
267
+
268
+ def collate_fn(batch):
269
+ batch = list(zip(*batch))
270
+ batch[0] = nested_tensor_from_tensor_list(batch[0])
271
+ return tuple(batch)
272
+
273
+
274
+ def _max_by_axis(the_list):
275
+ # type: (List[List[int]]) -> List[int]
276
+ maxes = the_list[0]
277
+ for sublist in the_list[1:]:
278
+ for index, item in enumerate(sublist):
279
+ maxes[index] = max(maxes[index], item)
280
+ return maxes
281
+
282
+
283
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
284
+ # TODO make this more general
285
+ if tensor_list[0].ndim == 3:
286
+ if torchvision._is_tracing():
287
+ # nested_tensor_from_tensor_list() does not export well to ONNX
288
+ # call _onnx_nested_tensor_from_tensor_list() instead
289
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
290
+
291
+ # TODO make it support different-sized images
292
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
293
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
294
+ batch_shape = [len(tensor_list)] + max_size
295
+ b, c, h, w = batch_shape
296
+ dtype = tensor_list[0].dtype
297
+ device = tensor_list[0].device
298
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
299
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
300
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
301
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
302
+ m[: img.shape[1], :img.shape[2]] = False
303
+ else:
304
+ raise ValueError('not supported')
305
+ return NestedTensor(tensor, mask)
306
+
307
+
308
+ # _onnx_nested_tensor_from_tensor_list() is an implementation of
309
+ # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
310
+ @torch.jit.unused
311
+ def _onnx_nested_tensor_from_tensor_list(tensor_list):
312
+ max_size = []
313
+ for i in range(tensor_list[0].dim()):
314
+ max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
315
+ max_size.append(max_size_i)
316
+ max_size = tuple(max_size)
317
+
318
+ # work around for
319
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
320
+ # m[: img.shape[1], :img.shape[2]] = False
321
+ # which is not yet supported in onnx
322
+ padded_imgs = []
323
+ padded_masks = []
324
+ for img in tensor_list:
325
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
326
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
327
+ padded_imgs.append(padded_img)
328
+
329
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
330
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
331
+ padded_masks.append(padded_mask.to(torch.bool))
332
+
333
+ tensor = torch.stack(padded_imgs)
334
+ mask = torch.stack(padded_masks)
335
+
336
+ return NestedTensor(tensor, mask=mask)
337
+
338
+
339
+ class NestedTensor(object):
340
+ def __init__(self, tensors, mask: Optional[Tensor]):
341
+ self.tensors = tensors
342
+ self.mask = mask
343
+
344
+ def to(self, device):
345
+ # type: (Device) -> NestedTensor # noqa
346
+ cast_tensor = self.tensors.to(device)
347
+ mask = self.mask
348
+ if mask is not None:
349
+ assert mask is not None
350
+ cast_mask = mask.to(device)
351
+ else:
352
+ cast_mask = None
353
+ return NestedTensor(cast_tensor, cast_mask)
354
+
355
+ def decompose(self):
356
+ return self.tensors, self.mask
357
+
358
+ def __repr__(self):
359
+ return str(self.tensors)
360
+
361
+
362
+ def setup_for_distributed(is_master):
363
+ """
364
+ This function disables printing when not in master process
365
+ """
366
+ import builtins as __builtin__
367
+ builtin_print = __builtin__.print
368
+
369
+ def print(*args, **kwargs):
370
+ force = kwargs.pop('force', False)
371
+ if is_master or force:
372
+ builtin_print(*args, **kwargs)
373
+
374
+ __builtin__.print = print
375
+
376
+
377
+ def is_dist_avail_and_initialized():
378
+ if not dist.is_available():
379
+ return False
380
+ if not dist.is_initialized():
381
+ return False
382
+ return True
383
+
384
+
385
+ def get_world_size():
386
+ if not is_dist_avail_and_initialized():
387
+ return 1
388
+ return dist.get_world_size()
389
+
390
+
391
+ def get_rank():
392
+ if not is_dist_avail_and_initialized():
393
+ return 0
394
+ return dist.get_rank()
395
+
396
+
397
+ def is_main_process():
398
+ return get_rank() == 0
399
+
400
+
401
+ def save_on_master(*args, **kwargs):
402
+ if is_main_process():
403
+ torch.save(*args, **kwargs)
404
+
405
+
406
+ def init_distributed_mode(args):
407
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
408
+ args.rank = int(os.environ["RANK"])
409
+ args.world_size = int(os.environ['WORLD_SIZE'])
410
+ args.gpu = int(os.environ['LOCAL_RANK'])
411
+ elif 'SLURM_PROCID' in os.environ:
412
+ args.rank = int(os.environ['SLURM_PROCID'])
413
+ args.gpu = args.rank % torch.cuda.device_count()
414
+ else:
415
+ print('Not using distributed mode')
416
+ args.distributed = False
417
+ return
418
+
419
+ args.distributed = True
420
+
421
+ torch.cuda.set_device(args.gpu)
422
+ args.dist_backend = 'nccl'
423
+ print('| distributed init (rank {}): {}'.format(
424
+ args.rank, args.dist_url), flush=True)
425
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
426
+ world_size=args.world_size, rank=args.rank)
427
+ torch.distributed.barrier()
428
+ setup_for_distributed(args.rank == 0)
429
+
430
+
431
+ @torch.no_grad()
432
+ def accuracy(output, target, topk=(1,)):
433
+ """Computes the precision@k for the specified values of k"""
434
+ if target.numel() == 0:
435
+ return [torch.zeros([], device=output.device)]
436
+ maxk = max(topk)
437
+ batch_size = target.size(0)
438
+
439
+ _, pred = output.topk(maxk, 1, True, True)
440
+ pred = pred.t()
441
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
442
+
443
+ res = []
444
+ for k in topk:
445
+ correct_k = correct[:k].view(-1).float().sum(0)
446
+ res.append(correct_k.mul_(100.0 / batch_size))
447
+ return res
448
+
449
+
450
+ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
451
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
452
+ """
453
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
454
+ This will eventually be supported natively by PyTorch, and this
455
+ class can go away.
456
+ """
457
+ if float(torchvision.__version__[:3]) < 0.7:
458
+ if input.numel() > 0:
459
+ return torch.nn.functional.interpolate(
460
+ input, size, scale_factor, mode, align_corners
461
+ )
462
+
463
+ output_shape = _output_size(2, input, size, scale_factor)
464
+ output_shape = list(input.shape[:-2]) + list(output_shape)
465
+ return _new_empty_tensor(input, output_shape)
466
+ else:
467
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
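Note: a minimal illustrative sketch of the padding helpers above (not part of the commit), assuming this file is importable as models.misc with its nested_tensor_from_tensor_list helper.

import torch
from models.misc import nested_tensor_from_tensor_list

# Two RGB images of different sizes are padded to a common shape;
# the boolean mask marks the padded positions of each sample.
imgs = [torch.rand(3, 256, 320), torch.rand(3, 240, 256)]
batch = nested_tensor_from_tensor_list(imgs)
tensors, mask = batch.decompose()
print(tensors.shape)  # torch.Size([2, 3, 256, 320])
print(mask.shape)     # torch.Size([2, 256, 320]); True where padding was added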
models/multi_head_attention.py ADDED
@@ -0,0 +1,537 @@
1
+ """
2
+ This file provides definition of multi head attention
3
+
4
+ borrowed from https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#MultiheadAttention
5
+ """
6
+ import warnings
7
+ from typing import Tuple, Optional
8
+
9
+ import torch
10
+ from torch import Tensor
11
+ from torch.nn.modules.linear import _LinearWithBias
12
+ from torch.nn.init import xavier_uniform_
13
+ from torch.nn.init import constant_
14
+ from torch.nn.init import xavier_normal_
15
+ from torch.nn.parameter import Parameter
16
+ from torch.nn.modules.module import Module
17
+ from torch.nn import functional as F
18
+ from torch.overrides import has_torch_function, handle_torch_function
19
+ from torch import _VF
20
+
21
+ # Activation functions
22
+ def dropout(input, p=0.5, training=True, inplace=False):
23
+ # type: (Tensor, float, bool, bool) -> Tensor
24
+ r"""
25
+ During training, randomly zeroes some of the elements of the input
26
+ tensor with probability :attr:`p` using samples from a Bernoulli
27
+ distribution.
28
+
29
+ See :class:`~torch.nn.Dropout` for details.
30
+
31
+ Args:
32
+ p: probability of an element to be zeroed. Default: 0.5
33
+ training: apply dropout if is ``True``. Default: ``True``
34
+ inplace: If set to ``True``, will do this operation in-place. Default: ``False``
35
+ """
36
+ if not torch.jit.is_scripting():
37
+ if type(input) is not Tensor and has_torch_function((input,)):
38
+ return handle_torch_function(
39
+ dropout, (input,), input, p=p, training=training, inplace=inplace)
40
+ if p < 0. or p > 1.:
41
+ raise ValueError("dropout probability has to be between 0 and 1, "
42
+ "but got {}".format(p))
43
+ return (_VF.dropout_(input, p, training)
44
+ if inplace
45
+ else _VF.dropout(input, p, training))
46
+
47
+
48
+ def _get_softmax_dim(name, ndim, stacklevel):
49
+ # type: (str, int, int) -> int
50
+ warnings.warn("Implicit dimension choice for {} has been deprecated. "
51
+ "Change the call to include dim=X as an argument.".format(name), stacklevel=stacklevel)
52
+ if ndim == 0 or ndim == 1 or ndim == 3:
53
+ ret = 0
54
+ else:
55
+ ret = 1
56
+ return ret
57
+
58
+ def softmax(input, dim=None, _stacklevel=3, dtype=None):
59
+ # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
60
+ r"""Applies a softmax function.
61
+ Softmax is defined as:
62
+ :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`
63
+ It is applied to all slices along dim, and will re-scale them so that the elements
64
+ lie in the range `[0, 1]` and sum to 1.
65
+ See :class:`~torch.nn.Softmax` for more details.
66
+ Args:
67
+ input (Tensor): input
68
+ dim (int): A dimension along which softmax will be computed.
69
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
70
+ If specified, the input tensor is casted to :attr:`dtype` before the operation
71
+ is performed. This is useful for preventing data type overflows. Default: None.
72
+ .. note::
73
+ This function doesn't work directly with NLLLoss,
74
+ which expects the Log to be computed between the Softmax and itself.
75
+ Use log_softmax instead (it's faster and has better numerical properties).
76
+ """
77
+ if not torch.jit.is_scripting():
78
+ if type(input) is not Tensor and has_torch_function((input,)):
79
+ return handle_torch_function(
80
+ softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
81
+ if dim is None:
82
+ dim = _get_softmax_dim('softmax', input.dim(), _stacklevel)
83
+ if dtype is None:
84
+ ret = input.softmax(dim)
85
+ else:
86
+ ret = input.softmax(dim, dtype=dtype)
87
+ return ret
88
+
89
+
90
+ def linear(input, weight, bias=None):
91
+ # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
92
+ r"""
93
+ Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.
94
+ This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.
95
+ Shape:
96
+ - Input: :math:`(N, *, in\_features)` N is the batch size, `*` means any number of
97
+ additional dimensions
98
+ - Weight: :math:`(out\_features, in\_features)`
99
+ - Bias: :math:`(out\_features)`
100
+ - Output: :math:`(N, *, out\_features)`
101
+ """
102
+ tens_ops = (input, weight)
103
+ if not torch.jit.is_scripting():
104
+ if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
105
+ return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
106
+ if input.dim() == 2 and bias is not None:
107
+ # fused op is marginally faster
108
+ ret = torch.addmm(bias, input, weight.t())
109
+ else:
110
+ output = input.matmul(weight.t())
111
+ if bias is not None:
112
+ output += bias
113
+ ret = output
114
+ return ret
115
+
116
+ def multi_head_attention_forward(query: Tensor, key: Tensor, value: Tensor, embed_dim_to_check: int,num_heads: int,
117
+ in_proj_weight: Tensor, in_proj_bias: Tensor, bias_k: Optional[Tensor], bias_v: Optional[Tensor], add_zero_attn: bool,
118
+ dropout_p: float, out_proj_weight: Tensor, out_proj_bias: Tensor, training: bool = True, key_padding_mask: Optional[Tensor] = None,
119
+ need_weights: bool = True, attn_mask: Optional[Tensor] = None, use_separate_proj_weight: bool = False, q_proj_weight: Optional[Tensor] = None,
120
+ k_proj_weight: Optional[Tensor] = None, v_proj_weight: Optional[Tensor] = None, static_k: Optional[Tensor] = None, static_v: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
121
+ r"""
122
+ Args:
123
+ query, key, value: map a query and a set of key-value pairs to an output.
124
+ See "Attention Is All You Need" for more details.
125
+ embed_dim_to_check: total dimension of the model.
126
+ num_heads: parallel attention heads.
127
+ in_proj_weight, in_proj_bias: input projection weight and bias.
128
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
129
+ add_zero_attn: add a new batch of zeros to the key and
130
+ value sequences at dim=1.
131
+ dropout_p: probability of an element to be zeroed.
132
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
133
+ training: apply dropout if is ``True``.
134
+ key_padding_mask: if provided, specified padding elements in the key will
135
+ be ignored by the attention. This is a binary mask. When the value is True,
136
+ the corresponding value on the attention layer will be filled with -inf.
137
+ need_weights: output attn_output_weights.
138
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
139
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
140
+ use_separate_proj_weight: the function accepts the proj. weights for query, key,
141
+ and value in different forms. If false, in_proj_weight will be used, which is
142
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
143
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
144
+ static_k, static_v: static key and value used for attention operators.
145
+ Shape:
146
+ Inputs:
147
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
148
+ the embedding dimension.
149
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
150
+ the embedding dimension.
151
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
152
+ the embedding dimension.
153
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
154
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
155
+ will be unchanged. If a BoolTensor is provided, the positions with the
156
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
157
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
158
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
159
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
160
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
161
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
162
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
163
+ is provided, it will be added to the attention weight.
164
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
165
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
166
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
167
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
168
+ Outputs:
169
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
170
+ E is the embedding dimension.
171
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
172
+ L is the target sequence length, S is the source sequence length.
173
+ """
174
+ if not torch.jit.is_scripting():
175
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
176
+ out_proj_weight, out_proj_bias)
177
+ if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
178
+ return handle_torch_function(
179
+ multi_head_attention_forward, tens_ops, query, key, value,
180
+ embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
181
+ bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
182
+ out_proj_bias, training=training, key_padding_mask=key_padding_mask,
183
+ need_weights=need_weights, attn_mask=attn_mask,
184
+ use_separate_proj_weight=use_separate_proj_weight,
185
+ q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
186
+ v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
187
+ tgt_len, bsz, embed_dim = query.size()
188
+ assert embed_dim == embed_dim_to_check
189
+ # allow MHA to have different sizes for the feature dimension
190
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
191
+
192
+ head_dim = embed_dim // num_heads
193
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
194
+ scaling = float(head_dim) ** -0.5
195
+
196
+ if not use_separate_proj_weight:
197
+ if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):
198
+ # self-attention
199
+ q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
200
+
201
+ elif (key is value or torch.equal(key, value)):
202
+ # encoder-decoder attention
203
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
204
+ _b = in_proj_bias
205
+ _start = 0
206
+ _end = embed_dim
207
+ _w = in_proj_weight[_start:_end, :]
208
+ if _b is not None:
209
+ _b = _b[_start:_end]
210
+ q = linear(query, _w, _b)
211
+
212
+ if key is None:
213
+ assert value is None
214
+ k = None
215
+ v = None
216
+ else:
217
+
218
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
219
+ _b = in_proj_bias
220
+ _start = embed_dim
221
+ _end = None
222
+ _w = in_proj_weight[_start:, :]
223
+ if _b is not None:
224
+ _b = _b[_start:]
225
+ k, v = linear(key, _w, _b).chunk(2, dim=-1)
226
+
227
+ else:
228
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
229
+ _b = in_proj_bias
230
+ _start = 0
231
+ _end = embed_dim
232
+ _w = in_proj_weight[_start:_end, :]
233
+ if _b is not None:
234
+ _b = _b[_start:_end]
235
+ q = linear(query, _w, _b)
236
+
237
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
238
+ _b = in_proj_bias
239
+ _start = embed_dim
240
+ _end = embed_dim * 2
241
+ _w = in_proj_weight[_start:_end, :]
242
+ if _b is not None:
243
+ _b = _b[_start:_end]
244
+ k = linear(key, _w, _b)
245
+
246
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
247
+ _b = in_proj_bias
248
+ _start = embed_dim * 2
249
+ _end = None
250
+ _w = in_proj_weight[_start:, :]
251
+ if _b is not None:
252
+ _b = _b[_start:]
253
+ v = linear(value, _w, _b)
254
+ else:
255
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
256
+ len1, len2 = q_proj_weight_non_opt.size()
257
+ assert len1 == embed_dim and len2 == query.size(-1)
258
+
259
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
260
+ len1, len2 = k_proj_weight_non_opt.size()
261
+ assert len1 == embed_dim and len2 == key.size(-1)
262
+
263
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
264
+ len1, len2 = v_proj_weight_non_opt.size()
265
+ assert len1 == embed_dim and len2 == value.size(-1)
266
+
267
+ if in_proj_bias is not None:
268
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
269
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
270
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
271
+ else:
272
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias)
273
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias)
274
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias)
275
+ q = q * scaling
276
+
277
+ if attn_mask is not None:
278
+ assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
279
+ attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
280
+ 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
281
+ if attn_mask.dtype == torch.uint8:
282
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
283
+ attn_mask = attn_mask.to(torch.bool)
284
+
285
+ if attn_mask.dim() == 2:
286
+ attn_mask = attn_mask.unsqueeze(0)
287
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
288
+ raise RuntimeError('The size of the 2D attn_mask is not correct.')
289
+ elif attn_mask.dim() == 3:
290
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
291
+ raise RuntimeError('The size of the 3D attn_mask is not correct.')
292
+ else:
293
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
294
+ # attn_mask's dim is 3 now.
295
+
296
+ # convert ByteTensor key_padding_mask to bool
297
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
298
+ warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
299
+ key_padding_mask = key_padding_mask.to(torch.bool)
300
+
301
+ if bias_k is not None and bias_v is not None:
302
+ if static_k is None and static_v is None:
303
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
304
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
305
+ if attn_mask is not None:
306
+ attn_mask = F.pad(attn_mask, (0, 1))  # use F.pad; a bare `pad` is not imported in this file
307
+ if key_padding_mask is not None:
308
+ key_padding_mask = F.pad(key_padding_mask, (0, 1))
309
+ else:
310
+ assert static_k is None, "bias cannot be added to static key."
311
+ assert static_v is None, "bias cannot be added to static value."
312
+ else:
313
+ assert bias_k is None
314
+ assert bias_v is None
315
+
316
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
317
+ if k is not None:
318
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
319
+ if v is not None:
320
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
321
+
322
+ if static_k is not None:
323
+ assert static_k.size(0) == bsz * num_heads
324
+ assert static_k.size(2) == head_dim
325
+ k = static_k
326
+
327
+ if static_v is not None:
328
+ assert static_v.size(0) == bsz * num_heads
329
+ assert static_v.size(2) == head_dim
330
+ v = static_v
331
+
332
+ src_len = k.size(1)
333
+
334
+ if key_padding_mask is not None:
335
+ assert key_padding_mask.size(0) == bsz
336
+ assert key_padding_mask.size(1) == src_len
337
+
338
+ if add_zero_attn:
339
+ src_len += 1
340
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
341
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
342
+ if attn_mask is not None:
343
+ attn_mask = F.pad(attn_mask, (0, 1))
344
+ if key_padding_mask is not None:
345
+ key_padding_mask = F.pad(key_padding_mask, (0, 1))
346
+
347
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
348
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
349
+
350
+ if attn_mask is not None:
351
+ if attn_mask.dtype == torch.bool:
352
+ attn_output_weights.masked_fill_(attn_mask, float('-inf'))
353
+ else:
354
+ attn_output_weights += attn_mask
355
+
356
+
357
+ if key_padding_mask is not None:
358
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
359
+ attn_output_weights = attn_output_weights.masked_fill(
360
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
361
+ float('-inf'),
362
+ )
363
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
364
+
365
+ attn_output_weights = softmax(
366
+ attn_output_weights, dim=-1)
367
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
368
+
369
+ attn_output = torch.bmm(attn_output_weights, v)
370
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
371
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
372
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
373
+
374
+ if need_weights:
375
+ # average attention weights over heads
376
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
377
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
378
+ else:
379
+ return attn_output, None
380
+
381
+ class MultiheadAttention(Module):
382
+ r"""Allows the model to jointly attend to information
383
+ from different representation subspaces.
384
+ See reference: Attention Is All You Need
385
+
386
+ .. math::
387
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
388
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
389
+
390
+ Args:
391
+ embed_dim: total dimension of the model.
392
+ num_heads: parallel attention heads.
393
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
394
+ bias: add bias as module parameter. Default: True.
395
+ add_bias_kv: add bias to the key and value sequences at dim=0.
396
+ add_zero_attn: add a new batch of zeros to the key and
397
+ value sequences at dim=1.
398
+ kdim: total number of features in key. Default: None.
399
+ vdim: total number of features in value. Default: None.
400
+
401
+ Note: if kdim and vdim are None, they will be set to embed_dim such that
402
+ query, key, and value have the same number of features.
403
+
404
+ Examples::
405
+
406
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
407
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
408
+ """
409
+ bias_k: Optional[torch.Tensor]
410
+ bias_v: Optional[torch.Tensor]
411
+
412
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
413
+ super(MultiheadAttention, self).__init__()
414
+ self.embed_dim = embed_dim
415
+ self.kdim = kdim if kdim is not None else embed_dim
416
+ self.vdim = vdim if vdim is not None else embed_dim
417
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
418
+
419
+ self.num_heads = num_heads
420
+ self.dropout = dropout
421
+ self.head_dim = embed_dim // num_heads
422
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
423
+
424
+ if self._qkv_same_embed_dim is False:
425
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
426
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
427
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
428
+ self.register_parameter('in_proj_weight', None)
429
+ else:
430
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
431
+ self.register_parameter('q_proj_weight', None)
432
+ self.register_parameter('k_proj_weight', None)
433
+ self.register_parameter('v_proj_weight', None)
434
+
435
+ if bias:
436
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
437
+ else:
438
+ self.register_parameter('in_proj_bias', None)
439
+ self.out_proj = _LinearWithBias(embed_dim, embed_dim)
440
+
441
+ if add_bias_kv:
442
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
443
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
444
+ else:
445
+ self.bias_k = self.bias_v = None
446
+
447
+ self.add_zero_attn = add_zero_attn
448
+
449
+ self._reset_parameters()
450
+
451
+ def _reset_parameters(self):
452
+ if self._qkv_same_embed_dim:
453
+ xavier_uniform_(self.in_proj_weight)
454
+ else:
455
+ xavier_uniform_(self.q_proj_weight)
456
+ xavier_uniform_(self.k_proj_weight)
457
+ xavier_uniform_(self.v_proj_weight)
458
+
459
+ if self.in_proj_bias is not None:
460
+ constant_(self.in_proj_bias, 0.)
461
+ constant_(self.out_proj.bias, 0.)
462
+ if self.bias_k is not None:
463
+ xavier_normal_(self.bias_k)
464
+ if self.bias_v is not None:
465
+ xavier_normal_(self.bias_v)
466
+
467
+ def __setstate__(self, state):
468
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
469
+ if '_qkv_same_embed_dim' not in state:
470
+ state['_qkv_same_embed_dim'] = True
471
+
472
+ super(MultiheadAttention, self).__setstate__(state)
473
+
474
+ def forward(self, query, key, value, key_padding_mask=None,
475
+ need_weights=True, attn_mask=None):
476
+ # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
477
+ r"""
478
+ Args:
479
+ query, key, value: map a query and a set of key-value pairs to an output.
480
+ See "Attention Is All You Need" for more details.
481
+ key_padding_mask: if provided, specified padding elements in the key will
482
+ be ignored by the attention. When given a binary mask and a value is True,
483
+ the corresponding value on the attention layer will be ignored. When given
484
+ a byte mask and a value is non-zero, the corresponding value on the attention
485
+ layer will be ignored
486
+ need_weights: output attn_output_weights.
487
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
488
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
489
+
490
+ Shape:
491
+ - Inputs:
492
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
493
+ the embedding dimension.
494
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
495
+ the embedding dimension.
496
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
497
+ the embedding dimension.
498
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
499
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero
500
+ positions will be unchanged. If a BoolTensor is provided, the positions with the
501
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
502
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
503
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
504
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
505
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
506
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
507
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
508
+ is provided, it will be added to the attention weight.
509
+
510
+ - Outputs:
511
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
512
+ E is the embedding dimension.
513
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
514
+ L is the target sequence length, S is the source sequence length.
515
+ """
516
+ if not self._qkv_same_embed_dim:
517
+ return multi_head_attention_forward(
518
+ query, key, value, self.embed_dim, self.num_heads,
519
+ self.in_proj_weight, self.in_proj_bias,
520
+ self.bias_k, self.bias_v, self.add_zero_attn,
521
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
522
+ training=self.training,
523
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
524
+ attn_mask=attn_mask, use_separate_proj_weight=True,
525
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
526
+ v_proj_weight=self.v_proj_weight)
527
+ else:
528
+ return multi_head_attention_forward(
529
+ query, key, value, self.embed_dim, self.num_heads,
530
+ self.in_proj_weight, self.in_proj_bias,
531
+ self.bias_k, self.bias_v, self.add_zero_attn,
532
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
533
+ training=self.training,
534
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
535
+ attn_mask=attn_mask)
536
+
537
+
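Note: this MultiheadAttention keeps the (L, N, E) layout of the stock PyTorch module, so a smoke test looks the same. The sketch below is illustrative only (not part of the commit) and assumes the file is importable as models.multi_head_attention.

import torch
from models.multi_head_attention import MultiheadAttention

mha = MultiheadAttention(embed_dim=256, num_heads=8)
q = torch.rand(100, 2, 256)   # (target length, batch, embed dim)
kv = torch.rand(120, 2, 256)  # (source length, batch, embed dim)
out, weights = mha(q, kv, kv)
print(out.shape)      # torch.Size([100, 2, 256])
print(weights.shape)  # torch.Size([2, 100, 120]), averaged over the 8 heads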
models/position_encoding.py ADDED
@@ -0,0 +1,89 @@
1
+ """
2
+ Various positional encodings for the transformer.
3
+ borrowed from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
4
+ """
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+
9
+ from .misc import NestedTensor
10
+
11
+
12
+ class PositionEmbeddingSine(nn.Module):
13
+ """
14
+ This is a more standard version of the position embedding, very similar to the one
15
+ used by the Attention is all you need paper, generalized to work on images.
16
+ """
17
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
18
+ super().__init__()
19
+ self.num_pos_feats = num_pos_feats
20
+ self.temperature = temperature
21
+ self.normalize = normalize
22
+ if scale is not None and normalize is False:
23
+ raise ValueError("normalize should be True if scale is passed")
24
+ if scale is None:
25
+ scale = 2 * math.pi
26
+ self.scale = scale
27
+
28
+ def forward(self, tensor_list: NestedTensor):
29
+ x = tensor_list.tensors
30
+ mask = tensor_list.mask
31
+ assert mask is not None
32
+ not_mask = ~mask
33
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
34
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
35
+ if self.normalize:
36
+ eps = 1e-6
37
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
38
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
39
+
40
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
41
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
42
+
43
+ pos_x = x_embed[:, :, :, None] / dim_t
44
+ pos_y = y_embed[:, :, :, None] / dim_t
45
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
46
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
47
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
48
+ return pos
49
+
50
+
51
+ class PositionEmbeddingLearned(nn.Module):
52
+ """
53
+ Absolute pos embedding, learned.
54
+ """
55
+ def __init__(self, num_pos_feats=256):
56
+ super().__init__()
57
+ self.row_embed = nn.Embedding(50, num_pos_feats)
58
+ self.col_embed = nn.Embedding(50, num_pos_feats)
59
+ self.reset_parameters()
60
+
61
+ def reset_parameters(self):
62
+ nn.init.uniform_(self.row_embed.weight)
63
+ nn.init.uniform_(self.col_embed.weight)
64
+
65
+ def forward(self, tensor_list: NestedTensor):
66
+ x = tensor_list.tensors
67
+ h, w = x.shape[-2:]
68
+ i = torch.arange(w, device=x.device)
69
+ j = torch.arange(h, device=x.device)
70
+ x_emb = self.col_embed(i)
71
+ y_emb = self.row_embed(j)
72
+ pos = torch.cat([
73
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
74
+ y_emb.unsqueeze(1).repeat(1, w, 1),
75
+ ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
76
+ return pos
77
+
78
+
79
+ def build_position_encoding(args):
80
+ N_steps = args.hidden_dim // 2
81
+ if args.position_embedding in ('v2', 'sine'):
82
+ # TODO find a better way of exposing other arguments
83
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
84
+ elif args.position_embedding in ('v3', 'learned'):
85
+ position_embedding = PositionEmbeddingLearned(N_steps)
86
+ else:
87
+ raise ValueError(f"not supported {args.position_embedding}")
88
+
89
+ return position_embedding
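Note: an illustrative sketch (not part of the commit) running the sine embedding on a dummy NestedTensor; with num_pos_feats=128 the output has 256 channels, matching the transformer's hidden dimension.

import torch
from models.misc import nested_tensor_from_tensor_list
from models.position_encoding import PositionEmbeddingSine

pos_enc = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
features = nested_tensor_from_tensor_list([torch.rand(256, 32, 40)])  # fake backbone feature map
pos = pos_enc(features)
print(pos.shape)  # torch.Size([1, 256, 32, 40])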
models/preprocessing.py ADDED
@@ -0,0 +1,71 @@
1
+ import torchvision.transforms.functional as functional
2
+
3
+ class Compose(object):
4
+ def __init__(self, transforms):
5
+ self.transforms = transforms
6
+
7
+ def __call__(self, image):
8
+ for t in self.transforms:
9
+ image = t(image)
10
+ return image
11
+
12
+ def __repr__(self):
13
+ format_string = self.__class__.__name__ + "("
14
+ for t in self.transforms:
15
+ format_string += "\n"
16
+ format_string += " {0}".format(t)
17
+ format_string += "\n)"
18
+ return format_string
19
+
20
+ class Normalize(object):
21
+ def __init__(self, mean, std):
22
+ self.mean = mean
23
+ self.std = std
24
+
25
+ def __call__(self, image):
26
+ image = functional.normalize(image, mean=self.mean, std=self.std)
27
+ return image
28
+
29
+ class ToTensor(object):
30
+ def __call__(self, img):
31
+ return functional.to_tensor(img)
32
+
33
+ def resize(image, size, max_size=None):
34
+ # size can be min_size (scalar) or (w, h) tuple
35
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
36
+ w, h = image_size
37
+ if max_size is not None:
38
+ min_original_size = float(min((w, h)))
39
+ max_original_size = float(max((w, h)))
40
+ if max_original_size / min_original_size * size > max_size:
41
+ size = int(round(max_size * min_original_size / max_original_size))
42
+ if (w <= h and w == size) or (h <= w and h == size):
43
+ return (h, w)
44
+ if w < h:
45
+ ow = size
46
+ oh = int(size * h / w)
47
+ else:
48
+ oh = size
49
+ ow = int(size * w / h)
50
+ return (oh, ow)
51
+
52
+ def get_size(image_size, size, max_size=None):
53
+ if isinstance(size, (list, tuple)):
54
+ return size[::-1]
55
+ else:
56
+ return get_size_with_aspect_ratio(image_size, size, max_size)
57
+
58
+ size = get_size(image.size, size, max_size)
59
+ rescaled_image = functional.resize(image, size)
60
+
61
+ return rescaled_image
62
+
63
+ class Resize(object):
64
+ def __init__(self, sizes, max_size=None):
65
+ assert isinstance(sizes, (list, tuple))
66
+ self.sizes = sizes
67
+ self.max_size = max_size
68
+
69
+ def __call__(self, img):
70
+ size = self.sizes
71
+ return resize(img, size, self.max_size)
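Note: an illustrative use of the transforms above (not part of the commit); ToTensor converts a PIL image to a CHW tensor, Normalize applies the dataset statistics, and Resize with a one-element list scales the shorter side while keeping the aspect ratio.

from PIL import Image
from models.preprocessing import Compose, ToTensor, Normalize, Resize

preprocess = Compose([
    ToTensor(),
    Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
    Resize([256]),
])
img = Image.new('RGB', (640, 480))  # stand-in for a real photo
tensor = preprocess(img)
print(tensor.shape)  # shorter side scaled to 256, e.g. torch.Size([3, 256, 341])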
models/transformer.py ADDED
@@ -0,0 +1,297 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ DETR Transformer class.
4
+
5
+ Copy-paste from torch.nn.Transformer with modifications:
6
+ * positional encodings are passed in MHattention
7
+ * extra LN at the end of encoder is removed
8
+ * decoder returns a stack of activations from all decoding layers
9
+ """
10
+ import copy
11
+ from typing import Optional, List
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch import nn, Tensor
16
+ from .multi_head_attention import MultiheadAttention
17
+
18
+ class Transformer(nn.Module):
19
+
20
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
21
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
22
+ activation="relu", normalize_before=False,
23
+ return_intermediate_dec=False):
24
+ super().__init__()
25
+
26
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
27
+ dropout, activation, normalize_before)
28
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
29
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
30
+
31
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
32
+ dropout, activation, normalize_before)
33
+ decoder_norm = nn.LayerNorm(d_model)
34
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
35
+ return_intermediate=return_intermediate_dec)
36
+
37
+ self._reset_parameters()
38
+
39
+ self.d_model = d_model
40
+ self.nhead = nhead
41
+
42
+ def _reset_parameters(self):
43
+ for p in self.parameters():
44
+ if p.dim() > 1:
45
+ nn.init.xavier_uniform_(p)
46
+
47
+ def forward(self, src, mask, query_embed, pos_embed):
48
+ # flatten NxCxHxW to HWxNxC
49
+ bs, c, h, w = src.shape
50
+ src = src.flatten(2).permute(2, 0, 1)
51
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
52
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
53
+ mask = mask.flatten(1)
54
+
55
+ tgt = torch.zeros_like(query_embed)
56
+
57
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
58
+
59
+ hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
60
+ pos=pos_embed, query_pos=query_embed)
61
+ return hs.transpose(1, 2), memory#.permute(1, 2, 0).view(bs, c, h, w)
62
+
63
+
64
+ class TransformerEncoder(nn.Module):
65
+
66
+ def __init__(self, encoder_layer, num_layers, norm=None):
67
+ super().__init__()
68
+ self.layers = _get_clones(encoder_layer, num_layers)
69
+ self.num_layers = num_layers
70
+ self.norm = norm
71
+
72
+ def forward(self, src,
73
+ mask: Optional[Tensor] = None,
74
+ src_key_padding_mask: Optional[Tensor] = None,
75
+ pos: Optional[Tensor] = None):
76
+ output = src
77
+
78
+ for layer in self.layers:
79
+ output = layer(output, src_mask=mask,
80
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
81
+
82
+ if self.norm is not None:
83
+ output = self.norm(output)
84
+
85
+ return output
86
+
87
+ class TransformerDecoder(nn.Module):
88
+
89
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
90
+ super().__init__()
91
+ self.layers = _get_clones(decoder_layer, num_layers)
92
+ self.num_layers = num_layers
93
+ self.norm = norm
94
+ self.return_intermediate = return_intermediate
95
+
96
+ def forward(self, tgt, memory,
97
+ tgt_mask: Optional[Tensor] = None,
98
+ memory_mask: Optional[Tensor] = None,
99
+ tgt_key_padding_mask: Optional[Tensor] = None,
100
+ memory_key_padding_mask: Optional[Tensor] = None,
101
+ pos: Optional[Tensor] = None,
102
+ query_pos: Optional[Tensor] = None):
103
+ output = tgt
104
+
105
+ intermediate = []
106
+
107
+ for layer in self.layers:
108
+ output = layer(output, memory, tgt_mask=tgt_mask,
109
+ memory_mask=memory_mask,
110
+ tgt_key_padding_mask=tgt_key_padding_mask,
111
+ memory_key_padding_mask=memory_key_padding_mask,
112
+ pos=pos, query_pos=query_pos)
113
+ if self.return_intermediate:
114
+ intermediate.append(self.norm(output))
115
+
116
+ if self.norm is not None:
117
+ output = self.norm(output)
118
+ if self.return_intermediate:
119
+ intermediate.pop()
120
+ intermediate.append(output)
121
+
122
+ if self.return_intermediate:
123
+ return torch.stack(intermediate)
124
+
125
+ return output.unsqueeze(0)
126
+
127
+
128
+ class TransformerEncoderLayer(nn.Module):
129
+
130
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
131
+ activation="relu", normalize_before=False):
132
+ super().__init__()
133
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
134
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
135
+ self.dropout = nn.Dropout(dropout)
136
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
137
+
138
+ self.norm1 = nn.LayerNorm(d_model)
139
+ self.norm2 = nn.LayerNorm(d_model)
140
+ self.dropout1 = nn.Dropout(dropout)
141
+ self.dropout2 = nn.Dropout(dropout)
142
+
143
+ self.activation = _get_activation_fn(activation)
144
+ self.normalize_before = normalize_before
145
+
146
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
147
+ return tensor if pos is None else tensor + pos
148
+
149
+ def forward_post(self,
150
+ src,
151
+ src_mask: Optional[Tensor] = None,
152
+ src_key_padding_mask: Optional[Tensor] = None,
153
+ pos: Optional[Tensor] = None):
154
+ q = k = self.with_pos_embed(src, pos)
155
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
156
+ key_padding_mask=src_key_padding_mask)[0]
157
+ src = src + self.dropout1(src2)
158
+ src = self.norm1(src)
159
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
160
+ src = src + self.dropout2(src2)
161
+ src = self.norm2(src)
162
+ return src
163
+
164
+ def forward_pre(self, src,
165
+ src_mask: Optional[Tensor] = None,
166
+ src_key_padding_mask: Optional[Tensor] = None,
167
+ pos: Optional[Tensor] = None):
168
+ src2 = self.norm1(src)
169
+ q = k = self.with_pos_embed(src2, pos)
170
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
171
+ key_padding_mask=src_key_padding_mask)[0]
172
+ src = src + self.dropout1(src2)
173
+ src2 = self.norm2(src)
174
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
175
+ src = src + self.dropout2(src2)
176
+ return src
177
+
178
+ def forward(self, src,
179
+ src_mask: Optional[Tensor] = None,
180
+ src_key_padding_mask: Optional[Tensor] = None,
181
+ pos: Optional[Tensor] = None):
182
+ if self.normalize_before:
183
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
184
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
185
+
186
+
187
+ class TransformerDecoderLayer(nn.Module):
188
+
189
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
190
+ activation="relu", normalize_before=False):
191
+ super().__init__()
192
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
193
+ self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
194
+ # Implementation of Feedforward model
195
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
196
+ self.dropout = nn.Dropout(dropout)
197
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
198
+
199
+ self.norm1 = nn.LayerNorm(d_model)
200
+ self.norm2 = nn.LayerNorm(d_model)
201
+ self.norm3 = nn.LayerNorm(d_model)
202
+ self.dropout1 = nn.Dropout(dropout)
203
+ self.dropout2 = nn.Dropout(dropout)
204
+ self.dropout3 = nn.Dropout(dropout)
205
+
206
+ self.activation = _get_activation_fn(activation)
207
+ self.normalize_before = normalize_before
208
+
209
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
210
+ return tensor if pos is None else tensor + pos
211
+
212
+ def forward_post(self, tgt, memory,
213
+ tgt_mask: Optional[Tensor] = None,
214
+ memory_mask: Optional[Tensor] = None,
215
+ tgt_key_padding_mask: Optional[Tensor] = None,
216
+ memory_key_padding_mask: Optional[Tensor] = None,
217
+ pos: Optional[Tensor] = None,
218
+ query_pos: Optional[Tensor] = None):
219
+ q = k = self.with_pos_embed(tgt, query_pos)
220
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
221
+ key_padding_mask=tgt_key_padding_mask)[0]
222
+ tgt = tgt + self.dropout1(tgt2)
223
+ tgt = self.norm1(tgt)
224
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
225
+ key=self.with_pos_embed(memory, pos),
226
+ value=memory, attn_mask=memory_mask,
227
+ key_padding_mask=memory_key_padding_mask)[0]
228
+ tgt = tgt + self.dropout2(tgt2)
229
+ tgt = self.norm2(tgt)
230
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
231
+ tgt = tgt + self.dropout3(tgt2)
232
+ tgt = self.norm3(tgt)
233
+ return tgt
234
+
235
+ def forward_pre(self, tgt, memory,
236
+ tgt_mask: Optional[Tensor] = None,
237
+ memory_mask: Optional[Tensor] = None,
238
+ tgt_key_padding_mask: Optional[Tensor] = None,
239
+ memory_key_padding_mask: Optional[Tensor] = None,
240
+ pos: Optional[Tensor] = None,
241
+ query_pos: Optional[Tensor] = None):
242
+ tgt2 = self.norm1(tgt)
243
+ q = k = self.with_pos_embed(tgt2, query_pos)
244
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
245
+ key_padding_mask=tgt_key_padding_mask)[0]
246
+ tgt = tgt + self.dropout1(tgt2)
247
+ tgt2 = self.norm2(tgt)
248
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
249
+ key=self.with_pos_embed(memory, pos),
250
+ value=memory, attn_mask=memory_mask,
251
+ key_padding_mask=memory_key_padding_mask)[0]
252
+ tgt = tgt + self.dropout2(tgt2)
253
+ tgt2 = self.norm3(tgt)
254
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
255
+ tgt = tgt + self.dropout3(tgt2)
256
+ return tgt
257
+
258
+ def forward(self, tgt, memory,
259
+ tgt_mask: Optional[Tensor] = None,
260
+ memory_mask: Optional[Tensor] = None,
261
+ tgt_key_padding_mask: Optional[Tensor] = None,
262
+ memory_key_padding_mask: Optional[Tensor] = None,
263
+ pos: Optional[Tensor] = None,
264
+ query_pos: Optional[Tensor] = None):
265
+ if self.normalize_before:
266
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
267
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
268
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
269
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
270
+
271
+
272
+ def _get_clones(module, N):
273
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
274
+
275
+
276
+ def build_transformer(args):
277
+
278
+ return Transformer(
279
+ d_model=args.hidden_dim,
280
+ dropout=args.dropout,
281
+ nhead=args.nheads,
282
+ dim_feedforward=args.dim_feedforward,
283
+ num_encoder_layers=args.enc_layers,
284
+ num_decoder_layers=args.dec_layers,
285
+ normalize_before=args.pre_norm,
286
+ return_intermediate_dec=True,
287
+ )
288
+
289
+ def _get_activation_fn(activation):
290
+ """Return an activation function given a string"""
291
+ if activation == "relu":
292
+ return F.relu
293
+ if activation == "gelu":
294
+ return F.gelu
295
+ if activation == "glu":
296
+ return F.glu
297
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
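Note: an illustrative forward pass through the transformer with dummy inputs (not part of the commit); the shapes mirror the (N, C, H, W) feature map, padding mask, learned queries, and positional encoding that the full LETR model supplies.

import torch
from models.transformer import Transformer

model = Transformer(d_model=256, nhead=8, num_encoder_layers=2,
                    num_decoder_layers=2, return_intermediate_dec=True)
bs, c, h, w = 1, 256, 16, 20
src = torch.rand(bs, c, h, w)                   # backbone feature map
mask = torch.zeros(bs, h, w, dtype=torch.bool)  # no padded pixels
query_embed = torch.rand(100, c)                # 100 dummy learned queries
pos_embed = torch.rand(bs, c, h, w)             # positional encoding
hs, memory = model(src, mask, query_embed, pos_embed)
print(hs.shape)      # torch.Size([2, 1, 100, 256]); one slice per decoder layer
print(memory.shape)  # torch.Size([320, 1, 256]); H*W encoder tokens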
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ torch==1.8.1
2
+ torchvision
3
+ gradio
4
+ jinja2
5
+ scipy
tappeto-per-calibrazione.jpg ADDED
test.py ADDED
@@ -0,0 +1,67 @@
1
+ from PIL import Image, ImageDraw
2
+
3
+ import torch.nn.functional as F
4
+ import torch
5
+
6
+ from models.letr import build
7
+ from models.misc import nested_tensor_from_tensor_list
8
+ from models.preprocessing import Compose, ToTensor, Resize, Normalize
9
+
10
+ def create_letr():
11
+ # obtain checkpoints
12
+ checkpoint = torch.load('checkpoint0024.pth', map_location='cpu')
13
+
14
+ # load model
15
+ args = checkpoint['args']
16
+ args.device = 'cpu'
17
+ model, _, _ = build(args)
18
+ model.load_state_dict(checkpoint['model'])
19
+ model.eval()
20
+ return model
21
+
22
+ def draw_fig(image, outputs, orig_size):
23
+ # find lines
24
+ out_logits, out_line = outputs['pred_logits'], outputs['pred_lines']
25
+ prob = F.softmax(out_logits, -1)
26
+ scores, labels = prob[..., :-1].max(-1)
27
+ img_h, img_w = orig_size.unbind(0)
28
+ scale_fct = torch.unsqueeze(torch.stack(
29
+ [img_w, img_h, img_w, img_h], dim=0), dim=0)
30
+ lines = out_line * scale_fct[:, None, :]
31
+ lines = lines.view(1000, 2, 2)
32
+ lines = lines.flip([-1]) # this is yxyx format
33
+ scores = scores.detach().numpy()
34
+ keep = scores >= 0.7
35
+ keep = keep.squeeze()
36
+ lines = lines[keep]
37
+ if len(lines) != 0:
38
+ lines = lines.reshape(lines.shape[0], -1)
39
+
40
+ # draw lines
41
+ draw = ImageDraw.Draw(image)
42
+ for tp_id, line in enumerate(lines):
43
+ y1, x1, y2, x2 = line
44
+ draw.line((x1, y1, x2, y2), fill=500)
45
+
46
+ if __name__ == '__main__':
47
+ model = create_letr()
48
+
49
+ test_size = 256
50
+ normalize = Compose([
51
+ ToTensor(),
52
+ Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
53
+ Resize([test_size]),
54
+ ])
55
+
56
+ image = Image.open('demo.png')
57
+ h, w = image.height, image.width
58
+ orig_size = torch.as_tensor([int(h), int(w)])
59
+
60
+ img = normalize(image)
61
+ inputs = nested_tensor_from_tensor_list([img])
62
+
63
+ with torch.no_grad():
64
+ outputs = model(inputs)[0]
65
+ draw_fig(image, outputs, orig_size)
66
+
67
+ image.save('output.png')