import torch
import typing as tp

from .transformer import StreamingTransformer, create_sin_embedding


class UnetTransformer(StreamingTransformer):
    """U-net Transformer for processing sequences with optional skip connections.

    This transformer architecture incorporates U-net style skip connections
    between layers, which can be optionally enabled. It inherits from a
    StreamingTransformer.

    Args:
        d_model (int): Dimension of the model, typically the number of expected features in the input.
        num_layers (int): Total number of layers in the transformer.
        skip_connections (bool, optional): Flag to determine whether skip connections should be used.
            Defaults to False.
        layer_dropout_p (float, optional): If given, defines the Bernoulli probability of dropping
            a skip connection during training.
        **kwargs: Additional keyword arguments forwarded to the parent `StreamingTransformer`.
""" | |
def __init__(self, d_model: int, num_layers: int, skip_connections: bool = False, | |
layer_dropout_p: tp.Optional[float] = None, **kwargs): | |
super().__init__(d_model=d_model, | |
num_layers=num_layers, | |
**kwargs) | |
self.skip_connect = skip_connections | |
if self.skip_connect: | |
self.skip_projections = torch.nn.ModuleList([torch.nn.Linear(d_model * 2, d_model) | |
for _ in range(num_layers // 2)]) | |
self.num_layers = num_layers | |
self.layer_drop_p = max(min(layer_dropout_p, 1.), 0.) if layer_dropout_p is not None else 0.0 | |

    def forward(self, x: torch.Tensor, *args, **kwargs):
        B, T, C = x.shape

        if 'offsets' in self._streaming_state:
            offsets = self._streaming_state['offsets']
        else:
            offsets = torch.zeros(B, dtype=torch.long, device=x.device)

        if self.positional_embedding in ['sin', 'sin_rope']:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + self.positional_scale * pos_emb

        skip_connections: tp.List[torch.Tensor] = []
        for i, layer in enumerate(self.layers):
            if self.skip_connect and i >= self.num_layers // 2:
                # in the second half of the layers, concatenate the stored U-net skip connection
                # and linearly project the concatenated features back to d_model
                x = torch.cat([x, skip_connections.pop()], dim=-1)
                x = self.skip_projections[i % len(self.skip_projections)](x)

            x = self._apply_layer(layer, x, *args, **kwargs)

            if self.skip_connect and i < self.num_layers // 2:
                if self.training and torch.rand(1,) < self.layer_drop_p:  # drop skip
                    skip_connections.append(torch.zeros_like(x))
                else:
                    skip_connections.append(x)

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return x
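

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes that the
    # parent StreamingTransformer (as in audiocraft) accepts `num_heads` along with
    # its standard transformer keyword arguments; the exact signature may differ.
    model = UnetTransformer(d_model=64, num_layers=8, num_heads=4,
                            skip_connections=True, layer_dropout_p=0.1)
    x = torch.randn(2, 50, 64)  # (batch, time, d_model)
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 50, 64])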