File size: 3,190 Bytes
907a484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import typing as tp
from .transformer import StreamingTransformer, create_sin_embedding


class UnetTransformer(StreamingTransformer):
    """U-Net style Transformer with optional long-range skip connections.

    The layer stack is split into an encoder-like first half and a decoder-like
    second half. When ``skip_connections`` is enabled, the output of each of the
    first ``num_layers // 2`` layers is stored. Each of the last ``num_layers // 2``
    layers pops the most recently stored output (LIFO, so the first layer is paired
    with the last one), concatenates it with its own input along the feature axis,
    and projects the result back to ``d_model`` with a dedicated linear layer before
    running the layer. With an odd number of layers, the middle layer receives no
    skip connection, so every stored activation is consumed exactly once.

    Args:
        d_model (int): Model dimension; the skip projections map ``2 * d_model`` features
            back to ``d_model``.
        num_layers (int): Total number of layers in the transformer.
        skip_connections (bool, optional): Whether U-Net skip connections are used.
            Defaults to False.
        layer_dropout_p (float, optional): If given, probability (clamped to [0, 1]) of
            replacing a stored skip activation with zeros while training. Defaults to
            None, which disables skip dropout.
        **kwargs: Additional keyword arguments forwarded to ``StreamingTransformer``.
    """
    def __init__(self, d_model: int, num_layers: int, skip_connections: bool = False,
                 layer_dropout_p: tp.Optional[float] = None, **kwargs):
        super().__init__(d_model=d_model,
                         num_layers=num_layers,
                         **kwargs)
        self.skip_connect = skip_connections
        if self.skip_connect:
            # One projection per decoder-side layer: concat(x, skip) -> d_model.
            self.skip_projections = torch.nn.ModuleList([torch.nn.Linear(d_model * 2, d_model)
                                                         for _ in range(num_layers // 2)])
        self.num_layers = num_layers
        self.layer_drop_p = max(min(layer_dropout_p, 1.), 0.) if layer_dropout_p is not None else 0.0

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Run the layer stack, adding U-Net skip connections when enabled.

        Args:
            x (torch.Tensor): Input unpacked as ``(B, T, C)``. ``C`` is used as the
                sinusoidal positional-embedding dimension.
            *args, **kwargs: Forwarded unchanged to every layer through ``_apply_layer``.

        Returns:
            torch.Tensor: Output of the last layer.
        """
        B, T, C = x.shape

        # Per-batch-item time offsets persisted in the streaming state between calls;
        # zero when no offsets have been stored yet.
        if 'offsets' in self._streaming_state:
            offsets = self._streaming_state['offsets']
        else:
            offsets = torch.zeros(B, dtype=torch.long, device=x.device)

        if self.positional_embedding in ['sin', 'sin_rope']:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + self.positional_scale * pos_emb

        num_skips = self.num_layers // 2
        # Index of the first layer that consumes a skip. Using `num_layers - num_skips`
        # (instead of `num_skips`) keeps pushes and pops balanced for odd layer counts,
        # where the middle layer gets no skip; for even counts the two are equal.
        first_decoder_layer = self.num_layers - num_skips
        skip_connections: tp.List[torch.Tensor] = []

        for i, layer in enumerate(self.layers):
            if self.skip_connect and i >= first_decoder_layer:
                # Decoder half: fuse the matching encoder activation (LIFO order) and
                # linearly project the concatenated features back to d_model.
                x = torch.cat([x, skip_connections.pop()], dim=-1)
                x = self.skip_projections[i - first_decoder_layer](x)

            x = self._apply_layer(layer, x, *args, **kwargs)

            if self.skip_connect and i < num_skips:
                # Encoder half: store this activation for the mirrored decoder layer.
                # While training, it is replaced by zeros with probability layer_drop_p.
                if self.training and torch.rand(1,) < self.layer_drop_p:
                    skip_connections.append(torch.zeros_like(x))
                else:
                    skip_connections.append(x)

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return x