Spaces: demo letr

Files changed:
- app.py +57 -0
- checkpoint0024.pth +3 -0
- demo.png +0 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-38.pyc +0 -0
- models/__pycache__/backbone.cpython-38.pyc +0 -0
- models/__pycache__/letr.cpython-38.pyc +0 -0
- models/__pycache__/letr_stack.cpython-38.pyc +0 -0
- models/__pycache__/matcher.cpython-38.pyc +0 -0
- models/__pycache__/misc.cpython-38.pyc +0 -0
- models/__pycache__/multi_head_attention.cpython-38.pyc +0 -0
- models/__pycache__/position_encoding.cpython-38.pyc +0 -0
- models/__pycache__/preprocessing.cpython-38.pyc +0 -0
- models/__pycache__/transformer.cpython-38.pyc +0 -0
- models/backbone.py +120 -0
- models/letr.py +371 -0
- models/letr_stack.py +376 -0
- models/matcher.py +81 -0
- models/misc.py +467 -0
- models/multi_head_attention.py +537 -0
- models/position_encoding.py +89 -0
- models/preprocessing.py +71 -0
- models/transformer.py +297 -0
- requirements.txt +5 -0
- tappeto-per-calibrazione.jpg +0 -0
- test.py +67 -0
app.py
ADDED
@@ -0,0 +1,57 @@
from PIL import Image, ImageDraw

import torch
from torchvision import transforms
import torch.nn.functional as F

import gradio as gr

# import sys
# sys.path.insert(0, './')
from test import create_letr, draw_fig
from models.preprocessing import *
from models.misc import nested_tensor_from_tensor_list


model = create_letr()

# PREPARE PREPROCESSING
test_size = 256
# transform_test = transforms.Compose([
#     transforms.Resize((test_size)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
# ])
normalize = Compose([
    ToTensor(),
    Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
    Resize([test_size]),
])


def predict(inp):
    image = Image.fromarray(inp.astype('uint8'), 'RGB')
    h, w = image.height, image.width
    orig_size = torch.as_tensor([int(h), int(w)])

    img = normalize(image)
    inputs = nested_tensor_from_tensor_list([img])

    with torch.no_grad():
        outputs = model(inputs)[0]

    draw_fig(image, outputs, orig_size)

    return image


inputs = gr.inputs.Image()
outputs = gr.outputs.Image()
gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    examples=["demo.png", "tappeto-per-calibrazione.jpg"],
    title="LETR",
    description="Model for line detection..."
).launch()
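For reference, a minimal headless-usage sketch (not part of this commit): it reproduces the preprocessing and inference path of app.py without launching the Gradio interface. It assumes the Space's files and checkpoint are in the working directory and that models.preprocessing exports Compose, ToTensor, Normalize and Resize, as the star import in app.py suggests.

# Headless usage sketch (assumptions noted above); run from the Space's root.
import torch
from PIL import Image

from test import create_letr, draw_fig
from models.preprocessing import Compose, ToTensor, Normalize, Resize
from models.misc import nested_tensor_from_tensor_list

model = create_letr()
normalize = Compose([
    ToTensor(),
    Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
    Resize([256]),
])

image = Image.open("demo.png").convert("RGB")
orig_size = torch.as_tensor([image.height, image.width])
inputs = nested_tensor_from_tensor_list([normalize(image)])
with torch.no_grad():
    outputs = model(inputs)[0]
draw_fig(image, outputs, orig_size)   # draws detected lines onto `image`
image.save("demo_with_lines.png")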
checkpoint0024.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:26725e48335937731ac968a3fbde602d296ca3edcf93b79f4f76f356ad3a4ff9
size 380893769
demo.png
ADDED
models/__init__.py
ADDED
File without changes
models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (143 Bytes)

models/__pycache__/backbone.cpython-38.pyc
ADDED
Binary file (4.75 kB)

models/__pycache__/letr.cpython-38.pyc
ADDED
Binary file (13.2 kB)

models/__pycache__/letr_stack.cpython-38.pyc
ADDED
Binary file (12.2 kB)

models/__pycache__/matcher.cpython-38.pyc
ADDED
Binary file (4.12 kB)

models/__pycache__/misc.cpython-38.pyc
ADDED
Binary file (14.6 kB)

models/__pycache__/multi_head_attention.cpython-38.pyc
ADDED
Binary file (19.7 kB)

models/__pycache__/position_encoding.cpython-38.pyc
ADDED
Binary file (3.65 kB)

models/__pycache__/preprocessing.cpython-38.pyc
ADDED
Binary file (2.98 kB)

models/__pycache__/transformer.cpython-38.pyc
ADDED
Binary file (9 kB)
models/backbone.py
ADDED
@@ -0,0 +1,120 @@
"""
LETR Backbone modules.
modified based on https://github.com/facebookresearch/detr/blob/master/models/backbone.py
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from .misc import NestedTensor, is_main_process

from .position_encoding import build_position_encoding


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out


class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)


class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos


def build_backbone(args):
    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0
    return_interm_layers = True
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model
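A small standalone check (not part of the commit) of what FrozenBatchNorm2d.forward computes: with frozen statistics the layer reduces to the per-channel affine map x * scale + bias, with scale = w / sqrt(rv + eps) and bias = b - rm * scale, which matches an eval-mode nn.BatchNorm2d with the same eps.

# Numerical sketch: frozen-BN arithmetic vs. eval-mode BatchNorm2d.
import torch

def frozen_bn(x, w, b, rm, rv, eps=1e-5):
    # same arithmetic as FrozenBatchNorm2d.forward above
    scale = w.reshape(1, -1, 1, 1) * (rv.reshape(1, -1, 1, 1) + eps).rsqrt()
    bias = b.reshape(1, -1, 1, 1) - rm.reshape(1, -1, 1, 1) * scale
    return x * scale + bias

torch.manual_seed(0)
c = 4
x = torch.randn(2, c, 8, 8)
w, b = torch.rand(c), torch.randn(c)
rm, rv = torch.randn(c), torch.rand(c) + 0.5

bn = torch.nn.BatchNorm2d(c, eps=1e-5).eval()
bn.weight.data, bn.bias.data = w.clone(), b.clone()
bn.running_mean.data, bn.running_var.data = rm.clone(), rv.clone()

with torch.no_grad():
    print(torch.allclose(frozen_bn(x, w, b, rm, rv), bn(x), atol=1e-5))  # True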
models/letr.py
ADDED
@@ -0,0 +1,371 @@
"""
This file provides coarse stage LETR definition
Modified based on https://github.com/facebookresearch/detr/blob/master/models/backbone.py
"""
import torch
import torch.nn.functional as F
from torch import nn

from .misc import (NestedTensor, nested_tensor_from_tensor_list,
                   accuracy, get_world_size, interpolate,
                   is_dist_avail_and_initialized)

from .backbone import build_backbone
from .matcher import build_matcher
from .transformer import build_transformer
from .letr_stack import LETRstack
import numpy as np


class LETR(nn.Module):
    """ This is the LETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, args, aux_loss=False):
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)

        self.lines_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        channel = [256, 512, 1024, 2048]
        self.input_proj = nn.Conv2d(channel[args.layer1_num], hidden_dim, kernel_size=1)

        self.backbone = backbone
        self.aux_loss = aux_loss
        self.args = args

    def forward(self, samples, postprocessors=None, targets=None, criterion=None):
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)

        features, pos = self.backbone(samples)

        num = self.args.layer1_num
        src, mask = features[num].decompose()
        assert mask is not None
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[num])[0]

        outputs_class = self.class_embed(hs)
        outputs_coord = self.lines_embed(hs).sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_lines': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        return [{'pred_logits': a, 'pred_lines': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]


class SetCriterion(nn.Module):

    def __init__(self, num_classes, weight_dict, eos_coef, losses, args, matcher=None):
        super().__init__()
        self.num_classes = num_classes

        self.matcher = matcher

        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)
        self.args = args
        try:
            self.args.label_loss_params = eval(self.args.label_loss_params)  # Convert the string to dict.
        except:
            pass

    def loss_lines_labels(self, outputs, targets, num_items, log=False, origin_indices=None):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_lines]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(origin_indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, origin_indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        if self.args.label_loss_func == 'cross_entropy':
            loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        elif self.args.label_loss_func == 'focal_loss':
            loss_ce = self.label_focal_loss(src_logits.transpose(1, 2), target_classes, self.empty_weight, **self.args.label_loss_params)
        else:
            raise ValueError()

        losses = {'loss_ce': loss_ce}
        return losses

    def label_focal_loss(self, input, target, weight, gamma=2.0):
        """ Focal loss for label prediction. """
        # In our case, target has 2 classes: 0 for foreground (i.e. line) and 1 for background.
        # The weight here can serve as the alpha hyperparameter in focal loss. However, in focal loss,
        #
        # Ref: https://github.com/facebookresearch/DETR/blob/699bf53f3e3ecd4f000007b8473eda6a08a8bed6/models/segmentation.py#L190
        # Ref: https://medium.com/visionwizard/understanding-focal-loss-a-quick-read-b914422913e7

        # input shape: [batch size, #classes, #queries]
        # target shape: [batch size, #queries]
        # weight shape: [#classes]

        prob = F.softmax(input, 1)  # Shape: [batch size, #classes, #queries].
        ce_loss = F.cross_entropy(input, target, weight, reduction='none')  # Shape: [batch size, #queries].
        p_t = prob[:, 1, :] * target + prob[:, 0, :] * (1 - target)  # Shape: [batch size, #queries]. Note: prob[:,0,:] + prob[:,1,:] should be 1.
        loss = ce_loss * ((1 - p_t) ** gamma)
        loss = loss.mean()  # Original label loss (i.e. cross entropy) does not consider the #lines, so we also do not consider that.
        return loss

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, num_items, origin_indices=None):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty lines
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_lines_POST(self, outputs, targets, num_items, origin_indices=None):
        assert 'POST_pred_lines' in outputs

        if outputs['POST_pred_lines'].shape[1] == 1000:
            idx = self._get_src_permutation_idx(origin_indices)
            src_lines = outputs['POST_pred_lines'][idx]
        else:
            src_lines = outputs['POST_pred_lines'].squeeze(0)

        target_lines = torch.cat([t['lines'][i] for t, (_, i) in zip(targets, origin_indices)], dim=0)

        loss_line = F.l1_loss(src_lines, target_lines, reduction='none')

        losses = {}
        losses['loss_line'] = loss_line.sum() / num_items

        return losses

    def loss_lines(self, outputs, targets, num_items, origin_indices=None):
        assert 'pred_lines' in outputs

        idx = self._get_src_permutation_idx(origin_indices)

        src_lines = outputs['pred_lines'][idx]
        target_lines = torch.cat([t['lines'][i] for t, (_, i) in zip(targets, origin_indices)], dim=0)

        loss_line = F.l1_loss(src_lines, target_lines, reduction='none')

        losses = {}
        losses['loss_line'] = loss_line.sum() / num_items

        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, num_items, **kwargs):
        loss_map = {
            'POST_lines_labels': self.loss_lines_labels,
            'POST_lines': self.loss_lines,
            'lines_labels': self.loss_lines_labels,
            'cardinality': self.loss_cardinality,
            'lines': self.loss_lines,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, num_items, **kwargs)

    def forward(self, outputs, targets, origin_indices=None):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        origin_indices = self.matcher(outputs_without_aux, targets)

        num_items = sum(len(t["labels"]) for t in targets)

        num_items = torch.as_tensor([num_items], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_items)
        num_items = torch.clamp(num_items / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, num_items, origin_indices=origin_indices))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        aux_name = 'aux_outputs'
        if aux_name in outputs:
            for i, aux_outputs in enumerate(outputs[aux_name]):

                origin_indices = self.matcher(aux_outputs, targets)

                for loss in self.losses:

                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False}

                    l_dict = self.get_loss(loss, aux_outputs, targets, num_items, origin_indices=origin_indices, **kwargs)

                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses


class PostProcess_Line(nn.Module):

    """ This module converts the model's output into the format expected by the coco api"""
    @torch.no_grad()
    def forward(self, outputs, target_sizes, output_type):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
        """
        if output_type == "prediction":
            out_logits, out_line = outputs['pred_logits'], outputs['pred_lines']

            assert len(out_logits) == len(target_sizes)
            assert target_sizes.shape[1] == 2

            prob = F.softmax(out_logits, -1)
            scores, labels = prob[..., :-1].max(-1)

            # convert to [x0, y0, x1, y1] format
            img_h, img_w = target_sizes.unbind(1)

            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
            lines = out_line * scale_fct[:, None, :]

            results = [{'scores': s, 'labels': l, 'lines': b} for s, l, b in zip(scores, labels, lines)]
        elif output_type == "prediction_POST":
            out_logits, out_line = outputs['pred_logits'], outputs['POST_pred_lines']

            assert len(out_logits) == len(target_sizes)
            assert target_sizes.shape[1] == 2

            prob = F.softmax(out_logits, -1)
            scores, labels = prob[..., :-1].max(-1)

            # convert to [x0, y0, x1, y1] format
            img_h, img_w = target_sizes.unbind(1)

            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
            lines = out_line * scale_fct[:, None, :]

            results = [{'scores': s, 'labels': l, 'lines': b} for s, l, b in zip(scores, labels, lines)]
        elif output_type == "ground_truth":
            results = []
            for dic in outputs:
                lines = dic['lines']
                img_h, img_w = target_sizes.unbind(1)
                scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
                scaled_lines = lines * scale_fct
                results.append({'labels': dic['labels'], 'lines': scaled_lines, 'image_id': dic['image_id']})
        else:
            assert False
        return results


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


def build(args):
    num_classes = 1

    device = torch.device(args.device)

    backbone = build_backbone(args)

    transformer = build_transformer(args)

    model = LETR(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        args=args,
        aux_loss=args.aux_loss,
    )

    if args.LETRpost:
        model = LETRstack(model, args=args)

    matcher = build_matcher(args, type='origin_line')

    losses = []
    weight_dict = {}

    if args.LETRpost:
        losses.append('POST_lines_labels')
        losses.append('POST_lines')
        weight_dict['loss_ce'] = 1
        weight_dict['loss_line'] = args.line_loss_coef
        aux_layer = args.second_dec_layers
    else:
        losses.append('lines_labels')
        losses.append('lines')
        weight_dict['loss_ce'] = 1
        weight_dict['loss_line'] = args.line_loss_coef
        aux_layer = args.dec_layers

    if args.aux_loss:
        aux_weight_dict = {}
        for i in range(aux_layer - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    criterion = SetCriterion(num_classes, weight_dict=weight_dict, eos_coef=args.eos_coef, losses=losses, args=args, matcher=matcher)
    criterion.to(device)

    postprocessors = {'line': PostProcess_Line()}

    return model, criterion, postprocessors
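A standalone shape sketch (not part of the commit) of the prediction heads used in LETR.forward above: the decoder output is projected to two class logits (line vs. no-object) and to four sigmoid-normalized endpoint coordinates per query. The sizes below (hidden_dim=256, 1000 queries, 6 decoder layers) are illustrative assumptions, not values read from this commit.

# Shape sketch of the LETR output heads; lines_embed mirrors MLP(hidden, hidden, 4, 3).
import torch
from torch import nn

num_layers, batch, num_queries, hidden_dim = 6, 2, 1000, 256
hs = torch.randn(num_layers, batch, num_queries, hidden_dim)   # stand-in for decoder output

class_embed = nn.Linear(hidden_dim, 1 + 1)                     # line vs. no-object logits
lines_embed = nn.Sequential(
    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
    nn.Linear(hidden_dim, 4),
)

outputs_class = class_embed(hs)              # [6, 2, 1000, 2]
outputs_coord = lines_embed(hs).sigmoid()    # [6, 2, 1000, 4] -> normalized [x0, y0, x1, y1]
out = {'pred_logits': outputs_class[-1], 'pred_lines': outputs_coord[-1]}
print(out['pred_logits'].shape, out['pred_lines'].shape)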
models/letr_stack.py
ADDED
@@ -0,0 +1,376 @@
"""
This file provides fine stage LETR definition

"""
import io
from collections import defaultdict
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from PIL import Image
from .misc import NestedTensor, nested_tensor_from_tensor_list
import copy


class LETRstack(nn.Module):
    def __init__(self, letr, args):
        super().__init__()
        self.letr = letr
        self.backbone = self.letr.backbone

        if args.layer1_frozen:
            # freeze backbone, encoder, decoder
            for n, p in self.named_parameters():
                p.requires_grad_(False)

        hidden_dim, nheads = letr.transformer.d_model, letr.transformer.nhead

        # add new input proj layer
        channel = [256, 512, 1024, 2048]
        self.input_proj = nn.Conv2d(channel[args.layer2_num], hidden_dim, kernel_size=1)

        # add new transformer encoder decoder
        self.transformer = Transformer(d_model=args.second_hidden_dim, dropout=args.second_dropout, nhead=args.second_nheads,
                                       dim_feedforward=args.second_dim_feedforward, num_encoder_layers=args.second_enc_layers,
                                       num_decoder_layers=args.second_dec_layers, normalize_before=args.second_pre_norm, return_intermediate_dec=True,)

        # output layer
        self.class_embed = nn.Linear(hidden_dim, 1 + 1)
        self.lines_embed = MLP(hidden_dim, hidden_dim, 4, 3)

        self.aux_loss = args.aux_loss
        self.args = args

    def forward(self, samples, postprocessors=None, targets=None, criterion=None):
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)

        # backbone
        features, pos = self.letr.backbone(samples)

        # layer 1
        l1_num = self.args.layer1_num
        src1, mask1 = features[l1_num].decompose()
        assert mask1 is not None

        # layer 1 transformer
        hs1, _ = self.letr.transformer(self.letr.input_proj(src1), mask1, self.letr.query_embed.weight, pos[l1_num])

        # layer 2
        l2_num = self.args.layer2_num
        src2, mask2 = features[l2_num].decompose()
        src2 = self.input_proj(src2)

        # layer 2 transformer
        hs2, memory, _ = self.transformer(src2, mask2, hs1[-1], pos[l2_num])

        outputs_class = self.class_embed(hs2)
        outputs_coord = self.lines_embed(hs2).sigmoid()
        out = {}
        out["pred_logits"] = outputs_class[-1]
        out["pred_lines"] = outputs_coord[-1]

        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)

        return out, None

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_lines': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

    @torch.jit.unused
    def _set_aux_loss_POST(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'POST_pred_lines': b} for b in outputs_coord[:-1]]


def _expand(tensor, length: int):
    return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        mask = mask.flatten(1)

        query_embed = query_embed.permute(1, 0, 2)
        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs, attn_output_weights = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory, attn_output_weights


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []
        attn_output_weights_list = []
        for layer in self.layers:
            output, attn_output_weights = layer(output, memory, tgt_mask=tgt_mask,
                                                memory_mask=memory_mask,
                                                tgt_key_padding_mask=tgt_key_padding_mask,
                                                memory_key_padding_mask=memory_key_padding_mask,
                                                pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))
                attn_output_weights_list.append(attn_output_weights)
        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate), attn_output_weights_list

        return output.unsqueeze(0), attn_output_weights


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2, attn_output_weights = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                                        key=self.with_pos_embed(memory, pos),
                                                        value=memory, attn_mask=memory_mask,
                                                        key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt, attn_output_weights

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
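A small data-flow sketch (not part of the commit) of how the fine stage stacks on the coarse stage: the last coarse decoder output hs1[-1] is passed to the second-stage Transformer as query_embed, which is permuted to sequence-first layout alongside the flattened higher-resolution feature map. The sizes used here are illustrative assumptions.

# Reshaping sketch mirroring Transformer.forward in letr_stack.py.
import torch

bs, c, h, w, num_queries = 2, 256, 32, 32, 1000

src = torch.randn(bs, c, h, w)               # projected layer-2 feature map
hs1_last = torch.randn(bs, num_queries, c)   # coarse-stage decoder output, hs1[-1]

src_seq = src.flatten(2).permute(2, 0, 1)    # [H*W, bs, c] encoder tokens
query_embed = hs1_last.permute(1, 0, 2)      # [num_queries, bs, c] query embeddings
tgt = torch.zeros_like(query_embed)          # decoder target starts from zeros

print(src_seq.shape, query_embed.shape, tgt.shape)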
models/matcher.py
ADDED
@@ -0,0 +1,81 @@
"""
Modules to compute the matching cost and solve the corresponding LSAP.
"""
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn


class HungarianMatcher_Line(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_line: float = 1):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_line: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_line = cost_line
        assert cost_class != 0 or cost_line != 0, "all costs can't be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_lines": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_lines] (where num_target_lines is the number of ground-truth
                           objects in the target) containing the class labels
                 "lines": Tensor of dim [num_target_lines, 4] containing the target box coordinates

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_lines)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]

        out_line = outputs["pred_lines"].flatten(0, 1)  # [batch_size * num_queries, 4]
        tgt_line = torch.cat([v["lines"] for v in targets])

        # Also concat the target labels and lines
        tgt_ids = torch.cat([v["labels"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between lines
        cost_line = torch.cdist(out_line, tgt_line, p=1)

        # Final cost matrix
        C = self.cost_line * cost_line + self.cost_class * cost_class
        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["lines"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]

        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


def build_matcher(args, type=None):
    return HungarianMatcher_Line(cost_class=args.set_cost_class, cost_line=args.set_cost_line)
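A toy, self-contained run of the same matching recipe (not part of the commit): build the classification and L1 endpoint costs for a single image and hand the combined matrix to SciPy's Hungarian solver, exactly as HungarianMatcher_Line does per batch element. All sizes and the unit cost weights below are illustrative.

# Toy Hungarian matching between 5 predicted lines and 2 targets.
import torch
from scipy.optimize import linear_sum_assignment

torch.manual_seed(0)
num_queries, num_targets = 5, 2

pred_logits = torch.randn(num_queries, 2)   # 2 classes: line, no-object
pred_lines = torch.rand(num_queries, 4)     # normalized [x0, y0, x1, y1]
tgt_labels = torch.zeros(num_targets, dtype=torch.int64)   # all targets are class 0 (line)
tgt_lines = torch.rand(num_targets, 4)

out_prob = pred_logits.softmax(-1)
cost_class = -out_prob[:, tgt_labels]                  # [num_queries, num_targets]
cost_line = torch.cdist(pred_lines, tgt_lines, p=1)    # L1 distance between endpoint vectors

C = 1.0 * cost_line + 1.0 * cost_class                 # cost_line / cost_class weights set to 1
row, col = linear_sum_assignment(C.numpy())
print(list(zip(row.tolist(), col.tolist())))           # query index -> target index assignment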
models/misc.py
ADDED
@@ -0,0 +1,467 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""
import os
import subprocess
import time
from collections import defaultdict, deque
import datetime
import pickle
from typing import Optional, List

import torch
import torch.distributed as dist
from torch import Tensor

# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
if float(torchvision.__version__[:3]) < 0.7:
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


def get_sha():
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
    sha = 'N/A'
    diff = "clean"
    branch = 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = "has uncommited changes" if diff else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


def collate_fn(batch):
    batch = list(zip(*batch))
    batch[0] = nested_tensor_from_tensor_list(batch[0])
    return tuple(batch)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
284 |
+
# TODO make this more general
|
285 |
+
if tensor_list[0].ndim == 3:
|
286 |
+
if torchvision._is_tracing():
|
287 |
+
# nested_tensor_from_tensor_list() does not export well to ONNX
|
288 |
+
# call _onnx_nested_tensor_from_tensor_list() instead
|
289 |
+
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
290 |
+
|
291 |
+
# TODO make it support different-sized images
|
292 |
+
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
293 |
+
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
294 |
+
batch_shape = [len(tensor_list)] + max_size
|
295 |
+
b, c, h, w = batch_shape
|
296 |
+
dtype = tensor_list[0].dtype
|
297 |
+
device = tensor_list[0].device
|
298 |
+
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
299 |
+
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
300 |
+
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
301 |
+
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
302 |
+
m[: img.shape[1], :img.shape[2]] = False
|
303 |
+
else:
|
304 |
+
raise ValueError('not supported')
|
305 |
+
return NestedTensor(tensor, mask)
|
306 |
+
|
307 |
+
|
308 |
+
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
309 |
+
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
310 |
+
@torch.jit.unused
|
311 |
+
def _onnx_nested_tensor_from_tensor_list(tensor_list):
|
312 |
+
max_size = []
|
313 |
+
for i in range(tensor_list[0].dim()):
|
314 |
+
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
|
315 |
+
max_size.append(max_size_i)
|
316 |
+
max_size = tuple(max_size)
|
317 |
+
|
318 |
+
# work around for
|
319 |
+
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
320 |
+
# m[: img.shape[1], :img.shape[2]] = False
|
321 |
+
# which is not yet supported in onnx
|
322 |
+
padded_imgs = []
|
323 |
+
padded_masks = []
|
324 |
+
for img in tensor_list:
|
325 |
+
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
|
326 |
+
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
327 |
+
padded_imgs.append(padded_img)
|
328 |
+
|
329 |
+
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
330 |
+
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
331 |
+
padded_masks.append(padded_mask.to(torch.bool))
|
332 |
+
|
333 |
+
tensor = torch.stack(padded_imgs)
|
334 |
+
mask = torch.stack(padded_masks)
|
335 |
+
|
336 |
+
return NestedTensor(tensor, mask=mask)
|
337 |
+
|
338 |
+
|
339 |
+
class NestedTensor(object):
|
340 |
+
def __init__(self, tensors, mask: Optional[Tensor]):
|
341 |
+
self.tensors = tensors
|
342 |
+
self.mask = mask
|
343 |
+
|
344 |
+
def to(self, device):
|
345 |
+
# type: (Device) -> NestedTensor # noqa
|
346 |
+
cast_tensor = self.tensors.to(device)
|
347 |
+
mask = self.mask
|
348 |
+
if mask is not None:
|
349 |
+
assert mask is not None
|
350 |
+
cast_mask = mask.to(device)
|
351 |
+
else:
|
352 |
+
cast_mask = None
|
353 |
+
return NestedTensor(cast_tensor, cast_mask)
|
354 |
+
|
355 |
+
def decompose(self):
|
356 |
+
return self.tensors, self.mask
|
357 |
+
|
358 |
+
def __repr__(self):
|
359 |
+
return str(self.tensors)
|
360 |
+
|
361 |
+
|
362 |
+
def setup_for_distributed(is_master):
|
363 |
+
"""
|
364 |
+
This function disables printing when not in master process
|
365 |
+
"""
|
366 |
+
import builtins as __builtin__
|
367 |
+
builtin_print = __builtin__.print
|
368 |
+
|
369 |
+
def print(*args, **kwargs):
|
370 |
+
force = kwargs.pop('force', False)
|
371 |
+
if is_master or force:
|
372 |
+
builtin_print(*args, **kwargs)
|
373 |
+
|
374 |
+
__builtin__.print = print
|
375 |
+
|
376 |
+
|
377 |
+
def is_dist_avail_and_initialized():
|
378 |
+
if not dist.is_available():
|
379 |
+
return False
|
380 |
+
if not dist.is_initialized():
|
381 |
+
return False
|
382 |
+
return True
|
383 |
+
|
384 |
+
|
385 |
+
def get_world_size():
|
386 |
+
if not is_dist_avail_and_initialized():
|
387 |
+
return 1
|
388 |
+
return dist.get_world_size()
|
389 |
+
|
390 |
+
|
391 |
+
def get_rank():
|
392 |
+
if not is_dist_avail_and_initialized():
|
393 |
+
return 0
|
394 |
+
return dist.get_rank()
|
395 |
+
|
396 |
+
|
397 |
+
def is_main_process():
|
398 |
+
return get_rank() == 0
|
399 |
+
|
400 |
+
|
401 |
+
def save_on_master(*args, **kwargs):
|
402 |
+
if is_main_process():
|
403 |
+
torch.save(*args, **kwargs)
|
404 |
+
|
405 |
+
|
406 |
+
def init_distributed_mode(args):
|
407 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
408 |
+
args.rank = int(os.environ["RANK"])
|
409 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
410 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
411 |
+
elif 'SLURM_PROCID' in os.environ:
|
412 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
413 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
414 |
+
else:
|
415 |
+
print('Not using distributed mode')
|
416 |
+
args.distributed = False
|
417 |
+
return
|
418 |
+
|
419 |
+
args.distributed = True
|
420 |
+
|
421 |
+
torch.cuda.set_device(args.gpu)
|
422 |
+
args.dist_backend = 'nccl'
|
423 |
+
print('| distributed init (rank {}): {}'.format(
|
424 |
+
args.rank, args.dist_url), flush=True)
|
425 |
+
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
426 |
+
world_size=args.world_size, rank=args.rank)
|
427 |
+
torch.distributed.barrier()
|
428 |
+
setup_for_distributed(args.rank == 0)
|
429 |
+
|
430 |
+
|
431 |
+
@torch.no_grad()
|
432 |
+
def accuracy(output, target, topk=(1,)):
|
433 |
+
"""Computes the precision@k for the specified values of k"""
|
434 |
+
if target.numel() == 0:
|
435 |
+
return [torch.zeros([], device=output.device)]
|
436 |
+
maxk = max(topk)
|
437 |
+
batch_size = target.size(0)
|
438 |
+
|
439 |
+
_, pred = output.topk(maxk, 1, True, True)
|
440 |
+
pred = pred.t()
|
441 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
442 |
+
|
443 |
+
res = []
|
444 |
+
for k in topk:
|
445 |
+
correct_k = correct[:k].view(-1).float().sum(0)
|
446 |
+
res.append(correct_k.mul_(100.0 / batch_size))
|
447 |
+
return res
|
448 |
+
|
449 |
+
|
450 |
+
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
451 |
+
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
452 |
+
"""
|
453 |
+
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
454 |
+
This will eventually be supported natively by PyTorch, and this
|
455 |
+
class can go away.
|
456 |
+
"""
|
457 |
+
if float(torchvision.__version__[:3]) < 0.7:
|
458 |
+
if input.numel() > 0:
|
459 |
+
return torch.nn.functional.interpolate(
|
460 |
+
input, size, scale_factor, mode, align_corners
|
461 |
+
)
|
462 |
+
|
463 |
+
output_shape = _output_size(2, input, size, scale_factor)
|
464 |
+
output_shape = list(input.shape[:-2]) + list(output_shape)
|
465 |
+
return _new_empty_tensor(input, output_shape)
|
466 |
+
else:
|
467 |
+
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
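The padding logic in nested_tensor_from_tensor_list is easiest to see on a concrete batch. A minimal usage sketch, assuming models/misc.py is importable from the repo root; the tensor sizes are made-up examples:

import torch
from models.misc import nested_tensor_from_tensor_list

# Two "images" of different spatial sizes (C, H, W); sizes are arbitrary for illustration.
imgs = [torch.rand(3, 200, 300), torch.rand(3, 256, 256)]

# Batch them into one zero-padded tensor plus a boolean mask marking the padded area.
batch = nested_tensor_from_tensor_list(imgs)
tensors, mask = batch.decompose()

print(tensors.shape)              # torch.Size([2, 3, 256, 300]) -- max H and max W over the batch
print(mask.shape)                 # torch.Size([2, 256, 300]); True where a pixel is padding
print(mask[0, :200, :300].any())  # tensor(False): the real image area is never masked

The same mask is what PositionEmbeddingSine in models/position_encoding.py consumes below, so positional encodings only accumulate over valid pixels.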
models/multi_head_attention.py
ADDED
@@ -0,0 +1,537 @@
"""
This file provides definition of multi head attention

borrowed from https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#MultiheadAttention
"""
import warnings
from typing import Tuple, Optional

import torch
from torch import Tensor
from torch.nn.modules.linear import _LinearWithBias
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.functional import pad  # used by the bias_k/bias_v and add_zero_attn branches below
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch.nn import functional as F
from torch.overrides import has_torch_function, handle_torch_function
from torch import _VF

# Activation functions
def dropout(input, p=0.5, training=True, inplace=False):
    # type: (Tensor, float, bool, bool) -> Tensor
    r"""
    During training, randomly zeroes some of the elements of the input
    tensor with probability :attr:`p` using samples from a Bernoulli
    distribution.

    See :class:`~torch.nn.Dropout` for details.

    Args:
        p: probability of an element to be zeroed. Default: 0.5
        training: apply dropout if is ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
    """
    if not torch.jit.is_scripting():
        if type(input) is not Tensor and has_torch_function((input,)):
            return handle_torch_function(
                dropout, (input,), input, p=p, training=training, inplace=inplace)
    if p < 0. or p > 1.:
        raise ValueError("dropout probability has to be between 0 and 1, "
                         "but got {}".format(p))
    return (_VF.dropout_(input, p, training)
            if inplace
            else _VF.dropout(input, p, training))


def _get_softmax_dim(name, ndim, stacklevel):
    # type: (str, int, int) -> int
    warnings.warn("Implicit dimension choice for {} has been deprecated. "
                  "Change the call to include dim=X as an argument.".format(name), stacklevel=stacklevel)
    if ndim == 0 or ndim == 1 or ndim == 3:
        ret = 0
    else:
        ret = 1
    return ret

def softmax(input, dim=None, _stacklevel=3, dtype=None):
    # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
    r"""Applies a softmax function.
    Softmax is defined as:
    :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`
    It is applied to all slices along dim, and will re-scale them so that the elements
    lie in the range `[0, 1]` and sum to 1.
    See :class:`~torch.nn.Softmax` for more details.
    Args:
        input (Tensor): input
        dim (int): A dimension along which softmax will be computed.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            If specified, the input tensor is casted to :attr:`dtype` before the operation
            is performed. This is useful for preventing data type overflows. Default: None.
    .. note::
        This function doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use log_softmax instead (it's faster and has better numerical properties).
    """
    if not torch.jit.is_scripting():
        if type(input) is not Tensor and has_torch_function((input,)):
            return handle_torch_function(
                softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
    if dim is None:
        dim = _get_softmax_dim('softmax', input.dim(), _stacklevel)
    if dtype is None:
        ret = input.softmax(dim)
    else:
        ret = input.softmax(dim, dtype=dtype)
    return ret


def linear(input, weight, bias=None):
    # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
    r"""
    Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.
    This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.
    Shape:
        - Input: :math:`(N, *, in\_features)` N is the batch size, `*` means any number of
          additional dimensions
        - Weight: :math:`(out\_features, in\_features)`
        - Bias: :math:`(out\_features)`
        - Output: :math:`(N, *, out\_features)`
    """
    tens_ops = (input, weight)
    if not torch.jit.is_scripting():
        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
            return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        ret = torch.addmm(bias, input, weight.t())
    else:
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias
        ret = output
    return ret

def multi_head_attention_forward(query: Tensor, key: Tensor, value: Tensor, embed_dim_to_check: int, num_heads: int,
                                 in_proj_weight: Tensor, in_proj_bias: Tensor, bias_k: Optional[Tensor], bias_v: Optional[Tensor], add_zero_attn: bool,
                                 dropout_p: float, out_proj_weight: Tensor, out_proj_bias: Tensor, training: bool = True, key_padding_mask: Optional[Tensor] = None,
                                 need_weights: bool = True, attn_mask: Optional[Tensor] = None, use_separate_proj_weight: bool = False, q_proj_weight: Optional[Tensor] = None,
                                 k_proj_weight: Optional[Tensor] = None, v_proj_weight: Optional[Tensor] = None, static_k: Optional[Tensor] = None, static_v: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accepts the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """
    if not torch.jit.is_scripting():
        tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
                    out_proj_weight, out_proj_bias)
        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
            return handle_torch_function(
                multi_head_attention_forward, tens_ops, query, key, value,
                embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
                bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
                out_proj_bias, training=training, key_padding_mask=key_padding_mask,
                need_weights=need_weights, attn_mask=attn_mask,
                use_separate_proj_weight=use_separate_proj_weight,
                q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
                v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):
            # self-attention
            q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif (key is value or torch.equal(key, value)):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:
                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
            v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
        else:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if attn_mask is not None:
        assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
            attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
            'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 2D attn_mask is not correct.')
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 3D attn_mask is not correct.')
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    # convert ByteTensor key_padding_mask to bool
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float('-inf'))
        else:
            attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = softmax(
        attn_output_weights, dim=-1)
    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None

class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

        Note: if kdim and vdim are None, they will be set to embed_dim such that
        query, key, and value have the same number of features.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = _LinearWithBias(embed_dim, embed_dim)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the position
              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        if not self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
models/position_encoding.py
ADDED
@@ -0,0 +1,89 @@
"""
Various positional encodings for the transformer.
borrowed from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
"""
import math
import torch
from torch import nn

from .misc import NestedTensor


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos


def build_position_encoding(args):
    N_steps = args.hidden_dim // 2
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif args.position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    return position_embedding
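The sine embedding consumes the NestedTensor mask from models/misc.py and returns a dense positional map with 2 * num_pos_feats channels (y features first, then x). A minimal sketch, assuming the modules above are importable; num_pos_feats=128 is an assumption that mirrors build_position_encoding with a hidden dim of 256:

import torch
from models.misc import nested_tensor_from_tensor_list
from models.position_encoding import PositionEmbeddingSine

pos_enc = PositionEmbeddingSine(num_pos_feats=128, normalize=True)

# Two feature maps of different spatial sizes get padded to a common grid first.
batch = nested_tensor_from_tensor_list([torch.rand(3, 64, 80), torch.rand(3, 48, 96)])
pos = pos_enc(batch)

print(pos.shape)  # torch.Size([2, 256, 64, 96]) -- (B, 2*num_pos_feats, H_max, W_max)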
models/preprocessing.py
ADDED
@@ -0,0 +1,71 @@
import torchvision.transforms.functional as functional

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image):
        for t in self.transforms:
            image = t(image)
        return image

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string

class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image = functional.normalize(image, mean=self.mean, std=self.std)
        return image

class ToTensor(object):
    def __call__(self, img):
        return functional.to_tensor(img)

def resize(image, size, max_size=None):
    # size can be min_size (scalar) or (w, h) tuple
    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))
        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)
        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)
        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    size = get_size(image.size, size, max_size)
    rescaled_image = functional.resize(image, size)

    return rescaled_image

class Resize(object):
    def __init__(self, sizes, max_size=None):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.max_size = max_size

    def __call__(self, img):
        size = self.sizes
        return resize(img, size, self.max_size)
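When resize is given a bare integer, get_size_with_aspect_ratio scales the shorter side to size and lets the longer side follow, optionally capped by max_size. A small sketch of that path (the 640x480 input is just an example):

from PIL import Image
from models.preprocessing import resize

img = Image.new("RGB", (640, 480))       # PIL size is (width, height)
out = resize(img, 256)                   # scalar size -> aspect-ratio-preserving branch
print(out.size)                          # (341, 256): shorter side 256, ratio kept

capped = resize(img, 256, max_size=300)  # max_size caps the longer side instead
print(capped.size)                       # (300, 225)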
models/transformer.py
ADDED
@@ -0,0 +1,297 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
DETR Transformer class.
|
4 |
+
|
5 |
+
Copy-paste from torch.nn.Transformer with modifications:
|
6 |
+
* positional encodings are passed in MHattention
|
7 |
+
* extra LN at the end of encoder is removed
|
8 |
+
* decoder returns a stack of activations from all decoding layers
|
9 |
+
"""
|
10 |
+
import copy
|
11 |
+
from typing import Optional, List
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch import nn, Tensor
|
16 |
+
from .multi_head_attention import MultiheadAttention
|
17 |
+
|
18 |
+
class Transformer(nn.Module):
|
19 |
+
|
20 |
+
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
|
21 |
+
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
|
22 |
+
activation="relu", normalize_before=False,
|
23 |
+
return_intermediate_dec=False):
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
27 |
+
dropout, activation, normalize_before)
|
28 |
+
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
29 |
+
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
30 |
+
|
31 |
+
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
|
32 |
+
dropout, activation, normalize_before)
|
33 |
+
decoder_norm = nn.LayerNorm(d_model)
|
34 |
+
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
|
35 |
+
return_intermediate=return_intermediate_dec)
|
36 |
+
|
37 |
+
self._reset_parameters()
|
38 |
+
|
39 |
+
self.d_model = d_model
|
40 |
+
self.nhead = nhead
|
41 |
+
|
42 |
+
def _reset_parameters(self):
|
43 |
+
for p in self.parameters():
|
44 |
+
if p.dim() > 1:
|
45 |
+
nn.init.xavier_uniform_(p)
|
46 |
+
|
47 |
+
def forward(self, src, mask, query_embed, pos_embed):
|
48 |
+
# flatten NxCxHxW to HWxNxC
|
49 |
+
bs, c, h, w = src.shape
|
50 |
+
src = src.flatten(2).permute(2, 0, 1)
|
51 |
+
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
|
52 |
+
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
53 |
+
mask = mask.flatten(1)
|
54 |
+
|
55 |
+
tgt = torch.zeros_like(query_embed)
|
56 |
+
|
57 |
+
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
58 |
+
|
59 |
+
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
|
60 |
+
pos=pos_embed, query_pos=query_embed)
|
61 |
+
return hs.transpose(1, 2), memory#.permute(1, 2, 0).view(bs, c, h, w)
|
62 |
+
|
63 |
+
|
64 |
+
class TransformerEncoder(nn.Module):
|
65 |
+
|
66 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
67 |
+
super().__init__()
|
68 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
69 |
+
self.num_layers = num_layers
|
70 |
+
self.norm = norm
|
71 |
+
|
72 |
+
def forward(self, src,
|
73 |
+
mask: Optional[Tensor] = None,
|
74 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
75 |
+
pos: Optional[Tensor] = None):
|
76 |
+
output = src
|
77 |
+
|
78 |
+
for layer in self.layers:
|
79 |
+
output = layer(output, src_mask=mask,
|
80 |
+
src_key_padding_mask=src_key_padding_mask, pos=pos)
|
81 |
+
|
82 |
+
if self.norm is not None:
|
83 |
+
output = self.norm(output)
|
84 |
+
|
85 |
+
return output
|
86 |
+
|
87 |
+
class TransformerDecoder(nn.Module):
|
88 |
+
|
89 |
+
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
90 |
+
super().__init__()
|
91 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
92 |
+
self.num_layers = num_layers
|
93 |
+
self.norm = norm
|
94 |
+
self.return_intermediate = return_intermediate
|
95 |
+
|
96 |
+
def forward(self, tgt, memory,
|
97 |
+
tgt_mask: Optional[Tensor] = None,
|
98 |
+
memory_mask: Optional[Tensor] = None,
|
99 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
100 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
101 |
+
pos: Optional[Tensor] = None,
|
102 |
+
query_pos: Optional[Tensor] = None):
|
103 |
+
output = tgt
|
104 |
+
|
105 |
+
intermediate = []
|
106 |
+
|
107 |
+
for layer in self.layers:
|
108 |
+
output = layer(output, memory, tgt_mask=tgt_mask,
|
109 |
+
memory_mask=memory_mask,
|
110 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
111 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
112 |
+
pos=pos, query_pos=query_pos)
|
113 |
+
if self.return_intermediate:
|
114 |
+
intermediate.append(self.norm(output))
|
115 |
+
|
116 |
+
if self.norm is not None:
|
117 |
+
output = self.norm(output)
|
118 |
+
if self.return_intermediate:
|
119 |
+
intermediate.pop()
|
120 |
+
intermediate.append(output)
|
121 |
+
|
122 |
+
if self.return_intermediate:
|
123 |
+
return torch.stack(intermediate)
|
124 |
+
|
125 |
+
return output.unsqueeze(0)
|
126 |
+
|
127 |
+
|
128 |
+
class TransformerEncoderLayer(nn.Module):
|
129 |
+
|
130 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
131 |
+
activation="relu", normalize_before=False):
|
132 |
+
super().__init__()
|
133 |
+
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
|
134 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
135 |
+
self.dropout = nn.Dropout(dropout)
|
136 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
137 |
+
|
138 |
+
self.norm1 = nn.LayerNorm(d_model)
|
139 |
+
self.norm2 = nn.LayerNorm(d_model)
|
140 |
+
self.dropout1 = nn.Dropout(dropout)
|
141 |
+
self.dropout2 = nn.Dropout(dropout)
|
142 |
+
|
143 |
+
self.activation = _get_activation_fn(activation)
|
144 |
+
self.normalize_before = normalize_before
|
145 |
+
|
146 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
147 |
+
return tensor if pos is None else tensor + pos
|
148 |
+
|
149 |
+
def forward_post(self,
|
150 |
+
src,
|
151 |
+
src_mask: Optional[Tensor] = None,
|
152 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
153 |
+
pos: Optional[Tensor] = None):
|
154 |
+
q = k = self.with_pos_embed(src, pos)
|
155 |
+
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
|
156 |
+
key_padding_mask=src_key_padding_mask)[0]
|
157 |
+
src = src + self.dropout1(src2)
|
158 |
+
src = self.norm1(src)
|
159 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
160 |
+
src = src + self.dropout2(src2)
|
161 |
+
src = self.norm2(src)
|
162 |
+
return src
|
163 |
+
|
164 |
+
def forward_pre(self, src,
|
165 |
+
src_mask: Optional[Tensor] = None,
|
166 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
167 |
+
pos: Optional[Tensor] = None):
|
168 |
+
src2 = self.norm1(src)
|
169 |
+
q = k = self.with_pos_embed(src2, pos)
|
170 |
+
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
|
171 |
+
key_padding_mask=src_key_padding_mask)[0]
|
172 |
+
src = src + self.dropout1(src2)
|
173 |
+
src2 = self.norm2(src)
|
174 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
175 |
+
src = src + self.dropout2(src2)
|
176 |
+
return src
|
177 |
+
|
178 |
+
def forward(self, src,
|
179 |
+
src_mask: Optional[Tensor] = None,
|
180 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
181 |
+
pos: Optional[Tensor] = None):
|
182 |
+
if self.normalize_before:
|
183 |
+
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
184 |
+
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
185 |
+
|
186 |
+
|
187 |
class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        # self-attention over the queries, then cross-attention into the encoder memory, then FFN
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        # pre-norm variant of the same self-attention / cross-attention / FFN sequence
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)

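A matching sketch for the decoder layer (again illustrative, not part of the diff); the 1000 queries mirror the count used later in test.py, while the batch size and memory length are assumptions:

    # illustrative only: 1000 line queries attending over 196 encoder tokens
    import torch
    from models.transformer import TransformerDecoderLayer

    layer = TransformerDecoderLayer(d_model=256, nhead=8)
    tgt = torch.zeros(1000, 2, 256)       # decoder queries
    memory = torch.rand(196, 2, 256)      # stand-in for the encoder output
    query_pos = torch.rand(1000, 2, 256)  # stand-in for the learned query embeddings
    pos = torch.rand(196, 2, 256)         # positional encoding of the memory
    out = layer(tgt, memory, pos=pos, query_pos=query_pos)  # shape: (1000, 2, 256)
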
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")

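build_transformer only needs an argparse-style namespace exposing the fields read above. A minimal sketch with illustrative values (not the ones stored in checkpoint0024.pth):

    # illustrative only: build a Transformer from a hand-made args namespace
    from types import SimpleNamespace
    from models.transformer import build_transformer

    args = SimpleNamespace(hidden_dim=256, dropout=0.1, nheads=8,
                           dim_feedforward=2048, enc_layers=6, dec_layers=6,
                           pre_norm=False)
    transformer = build_transformer(args)
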
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch==1.8.1
torchvision
gradio
jinja2
scipy

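These pinned dependencies can be installed in the usual way before running the demo locally, for example:

    pip install -r requirements.txt
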
tappeto-per-calibrazione.jpg
ADDED
test.py
ADDED
@@ -0,0 +1,67 @@
from PIL import Image, ImageDraw

import torch.nn.functional as F
import torch

from models.letr import build
from models.misc import nested_tensor_from_tensor_list
from models.preprocessing import Compose, ToTensor, Resize, Normalize


def create_letr():
    # load the trained checkpoint (weights plus the args it was trained with)
    checkpoint = torch.load('checkpoint0024.pth', map_location='cpu')

    # build the model on CPU and restore its weights
    args = checkpoint['args']
    args.device = 'cpu'
    model, _, _ = build(args)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model


def draw_fig(image, outputs, orig_size):
    # turn class logits into per-query confidence scores
    out_logits, out_line = outputs['pred_logits'], outputs['pred_lines']
    prob = F.softmax(out_logits, -1)
    scores, labels = prob[..., :-1].max(-1)

    # rescale normalized line endpoints back to the original image size
    img_h, img_w = orig_size.unbind(0)
    scale_fct = torch.unsqueeze(torch.stack(
        [img_w, img_h, img_w, img_h], dim=0), dim=0)
    lines = out_line * scale_fct[:, None, :]
    lines = lines.view(1000, 2, 2)
    lines = lines.flip([-1])  # this is yxyx format

    # keep only confident predictions
    scores = scores.detach().numpy()
    keep = scores >= 0.7
    keep = keep.squeeze()
    lines = lines[keep]
    if len(lines) != 0:
        lines = lines.reshape(lines.shape[0], -1)

        # draw the surviving lines onto the image
        draw = ImageDraw.Draw(image)
        for tp_id, line in enumerate(lines):
            y1, x1, y2, x2 = line
            draw.line((x1, y1, x2, y2), fill=500)


if __name__ == '__main__':
    model = create_letr()

    test_size = 256
    normalize = Compose([
        ToTensor(),
        Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
        Resize([test_size]),
    ])

    image = Image.open('demo.png')
    h, w = image.height, image.width
    orig_size = torch.as_tensor([int(h), int(w)])

    img = normalize(image)
    inputs = nested_tensor_from_tensor_list([img])

    with torch.no_grad():
        outputs = model(inputs)[0]
    draw_fig(image, outputs, orig_size)

    image.save('output.png')

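For a quick local check the script can be run directly; it loads checkpoint0024.pth, runs LETR on demo.png with the hard-coded 0.7 score threshold, and saves the annotated result:

    python test.py   # writes output.png next to the script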