# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import math
from typing import Any, Tuple

import tensorflow as tf
from tensorflow.keras import layers

from doctr.utils.repr import NestedObject

__all__ = ["PatchEmbedding"]


class PatchEmbedding(layers.Layer, NestedObject):
    """Compute 2D patch embeddings with cls token and positional encoding"""

    def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size: Tuple[int, int]) -> None:
        super().__init__()
        height, width, _ = input_shape
        self.patch_size = patch_size
        # interpolate_pos_encoding below assumes a square grid of pre-trained
        # positions, so interpolation is only enabled for square patches
        self.interpolate = patch_size[0] == patch_size[1]
        self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size))
        self.num_patches = self.grid_size[0] * self.grid_size[1]
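        # e.g. a (32, 128, 3) input with (4, 8) patches yields an (8, 16) grid,
        # i.e. 128 patch tokens (plus the cls token added in `call`)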

        self.cls_token = self.add_weight(shape=(1, 1, embed_dim), initializer="zeros", trainable=True, name="cls_token")
        self.positions = self.add_weight(
            shape=(1, self.num_patches + 1, embed_dim),
            initializer="zeros",
            trainable=True,
            name="positions",
        )
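        # a Conv2D with kernel_size == strides == patch_size is equivalent to
        # cutting the image into non-overlapping patches and applying a shared
        # linear projection to each flattened patch (the standard ViT patchify)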
        self.projection = layers.Conv2D(
            filters=embed_dim,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            padding="valid",
            data_format="channels_last",
            use_bias=True,
            kernel_initializer="glorot_uniform",
            bias_initializer="zeros",
            name="projection",
        )

    def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
        """100 % borrowed from:
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_tf_vit.py

        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py
        """
        seq_len, dim = embeddings.shape[1:]
        num_patches = seq_len - 1

        num_positions = self.positions.shape[1] - 1

        if num_patches == num_positions and height == width:
            return self.positions
        # the cls position is kept as-is; only the patch grid is interpolated
        class_pos_embed = self.positions[:, :1]
        patch_pos_embed = self.positions[:, 1:]
        # target grid size at the new resolution
        h0 = height // self.patch_size[0]
        w0 = width // self.patch_size[1]
        patch_pos_embed = tf.image.resize(
            images=tf.reshape(
                patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
            ),
            size=(h0, w0),
            method="bilinear",
        )

        shape = patch_pos_embed.shape
        assert h0 == shape[-3], "height of interpolated patch embedding doesn't match"
        assert w0 == shape[-2], "width of interpolated patch embedding doesn't match"

        patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
        return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
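
    # Worked example (illustrative, assuming square 16x16 patches and positions
    # pre-trained on a 14x14 grid, i.e. 196 + 1 tokens): a 384x384 input gives
    # h0 = w0 = 24, so the (1, 14, 14, dim) position grid is bilinearly resized
    # to (1, 24, 24, dim), flattened back to (1, 576, dim), and the cls position
    # is concatenated back in front.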

    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
        B, H, W, C = x.shape
        assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height"
        assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width"
        # patchify image
        patches = self.projection(x, **kwargs)  # (batch_size, H // patch_h, W // patch_w, d_model)
        patches = tf.reshape(patches, (B, self.num_patches, -1))  # (batch_size, num_patches, d_model)

        cls_tokens = tf.repeat(self.cls_token, B, axis=0)  # (batch_size, 1, d_model)
        # prepend the cls token to the patch tokens
        embeddings = tf.concat([cls_tokens, patches], axis=1)  # (batch_size, num_patches + 1, d_model)
        # add positions to embeddings
        if self.interpolate:
            embeddings += self.interpolate_pos_encoding(embeddings, H, W)
        else:
            embeddings += self.positions

        return embeddings  # (batch_size, num_patches + 1, d_model)
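

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the library API):
    # embed two 32x128 RGB crops into 384-dim tokens using 4x8 patches.
    embedder = PatchEmbedding(input_shape=(32, 128, 3), embed_dim=384, patch_size=(4, 8))
    out = embedder(tf.zeros((2, 32, 128, 3)))
    print(out.shape)  # (2, 129, 384): 128 patch tokens + 1 cls token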