# UnlimitedMusicGen/audiocraft/modules/unet_transformer.py
import torch
import typing as tp
from .transformer import StreamingTransformer, create_sin_embedding


class UnetTransformer(StreamingTransformer):
"""U-net Transformer for processing sequences with optional skip connections.
    This transformer incorporates optional U-net style skip connections between the
    first half of the layers and their mirrored counterparts in the second half.
    It inherits from `StreamingTransformer`.

Args:
d_model (int): Dimension of the model, typically the number of expected features in the input.
num_layers (int): Total number of layers in the transformer.
skip_connections (bool, optional): Flag to determine whether skip connections should be used.
Defaults to False.
        layer_dropout_p (float, optional): If given, the Bernoulli probability of dropping a skip
            connection during training (the value is clamped to [0, 1]).
        **kwargs: Additional keyword arguments forwarded to `StreamingTransformer`.
"""
def __init__(self, d_model: int, num_layers: int, skip_connections: bool = False,
layer_dropout_p: tp.Optional[float] = None, **kwargs):
super().__init__(d_model=d_model,
num_layers=num_layers,
**kwargs)
self.skip_connect = skip_connections
if self.skip_connect:
self.skip_projections = torch.nn.ModuleList([torch.nn.Linear(d_model * 2, d_model)
for _ in range(num_layers // 2)])
self.num_layers = num_layers
        self.layer_drop_p = max(min(layer_dropout_p, 1.), 0.) if layer_dropout_p is not None else 0.0

def forward(self, x: torch.Tensor, *args, **kwargs):
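        """Apply the transformer layers to `x`, using U-net skip connections when enabled.

        Args:
            x (torch.Tensor): Input of shape `[B, T, C]` with `C == d_model`.
        Returns:
            torch.Tensor: Output tensor of the same shape as `x`.
        """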
B, T, C = x.shape
if 'offsets' in self._streaming_state:
offsets = self._streaming_state['offsets']
else:
offsets = torch.zeros(B, dtype=torch.long, device=x.device)
if self.positional_embedding in ['sin', 'sin_rope']:
positions = torch.arange(T, device=x.device).view(1, -1, 1)
positions = positions + offsets.view(-1, 1, 1)
pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
x = x + self.positional_scale * pos_emb
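
        # the first half of the layers acts as the U-net "encoder": its outputs are
        # stacked so the mirrored layers in the second half can consume them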
skip_connections: tp.List[torch.Tensor] = []
for i, layer in enumerate(self.layers):
if self.skip_connect and i >= self.num_layers // 2:
                # in the second half of the layers, pop the matching first-half output,
                # concatenate it with the current features (U-net skip connection) and
                # linearly project the result back to d_model
x = torch.cat([x, skip_connections.pop()], dim=-1)
x = self.skip_projections[i % len(self.skip_projections)](x)
x = self._apply_layer(layer, x, *args, **kwargs)
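
            # in the first half of the layers, stash the output for the mirrored
            # second-half layer; during training it may be dropped (replaced by zeros)
            # with probability layer_drop_p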
if self.skip_connect and i < self.num_layers // 2:
if self.training and torch.rand(1,) < self.layer_drop_p: # drop skip
skip_connections.append(torch.zeros_like(x))
else:
skip_connections.append(x)
if self._is_streaming:
self._streaming_state['offsets'] = offsets + T
return x
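

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes the
    # usual StreamingTransformer keyword arguments (e.g. `num_heads`); adjust to
    # match the constructor signature in your checkout if it differs.
    model = UnetTransformer(
        d_model=64,
        num_heads=4,            # assumed StreamingTransformer kwarg
        num_layers=6,           # even count so encoder/decoder halves mirror each other
        skip_connections=True,
        layer_dropout_p=0.1,
    )
    dummy = torch.randn(2, 16, 64)  # [B, T, C] with C == d_model
    out = model(dummy)
    assert out.shape == dummy.shape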