Spaces:
Sleeping
Sleeping
# Copyright 2024 The YourMT3 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Please see the details in the LICENSE file. | |
""" projection_layer.py """ | |
from typing import Tuple | |
import math | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from torch.nn import Linear, LayerNorm | |
from einops import rearrange | |
from model.ops import count_parameters | |
class GroupLinearFlatten(nn.Module): | |
""" | |
Implements a grouped linear layer with a flattened output. | |
This module applies individual linear transformations for each group in the input tensor | |
and then flattens the group dimension to produce the final output. It's useful when you | |
have distinct groups in the input tensor and you want separate linear transformations for | |
each of these groups. | |
Args: | |
- in_features (int): The number of input features per group. | |
- flatten_out_features (int): The total number of flattened output features. This value must | |
be divisible by num_groups. The actual number of output features | |
per group is computed as flatten_out_features/num_groups. | |
- num_groups (int): The number of distinct groups in the input tensor. | |
- use_bmm (bool, optional): Whether to use batch matrix multiplication for computation. | |
Default is True. | |
Shape: | |
- Input: (batch_size, sequence_length, num_groups, in_features) | |
- Output: (batch_size, sequence_length, flatten_out_features) | |
Examples: | |
>>> m = GroupLinearFlatten(128, 512, 24) # | |
>>> input = torch.randn(16, 10, 24, 128) # (B, T, C, F) | |
>>> output = m(input) | |
>>> output.size() | |
torch.Size([16, 10, 512]) # (B, T, D) | |
""" | |
def __init__(self, in_features, flatten_out_features, num_groups, use_bmm=True): | |
super().__init__() | |
self.in_features = in_features | |
self.flatten_out_features = flatten_out_features | |
self.num_groups = num_groups | |
self.use_bmm = use_bmm | |
# Assuming flatten_out_features is divisible by num_groups | |
self.out_features_per_group = self.flatten_out_features // self.num_groups | |
# Each group gets its own weights | |
self.weight = nn.Parameter(torch.Tensor(num_groups, self.out_features_per_group, in_features)) | |
self.reset_parameters() | |
def reset_parameters(self): | |
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) | |
def forward(self, input): | |
# input shape: (batch, seq_length, groups, in_features) | |
# weight shape: (groups, out_features_per_group, in_features) | |
batch_size, t, k, source_d = input.size() | |
if self.use_bmm: | |
# Reshape input for bmm operation | |
input_reshaped = rearrange(input, 'b t k d -> k d (b t)') | |
# Matrix multiplication: dot((k, out_features_per_group, d), (k, d, b*t)) -> (k, out_features_per_group, b*t) | |
output_bmm = torch.bmm(self.weight, input_reshaped) | |
# Reshape back to original shape and flatten the group dimension | |
output = rearrange(output_bmm, 'k d_out (b t) -> b t (k d_out)', b=batch_size, t=t, k=k) | |
else: | |
output = torch.einsum('bsgi,goi->bsgo', input, self.weight) | |
output = rearrange(output, 'b t k d_out -> b t (k d_out)') | |
return output | |
# class MultiChannelGroupLinear(nn.Module): | |
# """ Not Implemented Yet """ | |
# def __init__(self, in_ch=26, in_dim=128, out_ch=13, out_dim=512): | |
# super().__init__() | |
# self.in_ch = in_ch | |
# self.in_dim = in_dim | |
# self.out_ch = out_ch | |
# self.out_dim = out_dim | |
# self.in_ch_per_group = in_ch // out_ch | |
# self.layer = GroupLinearFlatten(in_features=) | |
class MultiChannelLinearProjection(nn.Module): | |
def __init__(self, in_ch=26, in_dim=128, out_ch=13, out_dim=512): | |
super().__init__() | |
self.in_ch = in_ch | |
self.in_dim = in_dim | |
self.out_ch = out_ch | |
self.out_dim = out_dim | |
self.in_ch_per_group = in_ch // out_ch | |
self.linear_in_ch = in_ch // self.in_ch_per_group | |
self.linear_in_dim = in_dim * self.in_ch_per_group | |
# Reshaped Input shape: (b, t, in_dim//in_ch_per_group, in_dim*in_ch_per_group) | |
# Output shape: (b, t, out_ch, out_dim) | |
if in_dim * self.in_ch_per_group == out_dim: | |
self.linear = nn.Identity() | |
else: | |
self.linear = nn.Linear(in_features=self.linear_in_dim, out_features=out_dim, bias=False) | |
def forward(self, x): | |
""" | |
Args: | |
x: (B, T, C, D) | |
Returns: | |
x: (B, C_target, T, D_target) | |
""" | |
x = rearrange(x, 'b t (c1 c2) d -> b c1 t (c2 d)', c1=self.linear_in_ch, c2=self.in_ch_per_group) | |
return self.linear(x) | |
def get_multi_channel_projection_layer(input_shape: Tuple[int], output_shape: Tuple[int], proj_type: str) -> nn.Module: | |
""" This function returns one of the projection layers for multi-channel models.""" | |
in_ch = input_shape[-2] | |
in_dim = input_shape[-1] | |
out_ch = output_shape[-2] | |
out_dim = output_shape[-1] | |
if proj_type == 'mc_shared_linear': | |
return MultiChannelLinearProjection(in_ch, in_dim, out_ch, out_dim) | |
def test_multi_channel_linear_projection(): | |
x = torch.randn(2, 10, 26, 128) # (b, t, c, d) | |
mclp = MultiChannelLinearProjection(in_ch=26, in_dim=128, out_ch=13, out_dim=256) # actually nn.Identity() | |
assert type(nn.Identity()) == type(mclp.linear) | |
assert mclp(x).shape == (2, 13, 10, 256) # (b, _c, t, _d) | |
x = torch.randn(2, 10, 26, 128) # (b, t, c, d) | |
mclp = MultiChannelLinearProjection(in_ch=26, in_dim=128, out_ch=13, out_dim=512) # actually nn.Identity() | |
assert torch.nn.modules.linear.Linear == type(mclp.linear) | |
assert mclp(x).shape == (2, 13, 10, 512) # (b, _c, t, _d) | |
class FlattenMLP(nn.Module): | |
def __init__(self, in_features, flatten_out_features, num_groups, hidden_dim=None, activation=None): | |
super().__init__() | |
self.in_features = in_features | |
self.num_groups = num_groups | |
# Calculate flattened input dimension | |
self.flat_in_dim = in_features * num_groups | |
if hidden_dim is None: | |
hidden_dim = self.flat_in_dim // 2 | |
self.hidden_dim = hidden_dim | |
# Check if flatten_out_features is divisible by in_features | |
assert flatten_out_features % in_features == 0, "flatten_out_features should be divisible by in_features." | |
# Define layers | |
self.layers = nn.Sequential(nn.Flatten(2, 3), nn.Linear(self.flat_in_dim, hidden_dim), nn.LayerNorm(hidden_dim), | |
activation() if activation else nn.Identity(), nn.Linear(hidden_dim, hidden_dim), | |
nn.LayerNorm(hidden_dim), | |
activation() if activation else nn.Identity(), | |
nn.Linear(hidden_dim, flatten_out_features)) | |
def forward(self, x): | |
# x shape: (batch, seq, num_groups, in_features) | |
return self.layers(x) | |
class LinearProjection(nn.Module): | |
def __init__(self, in_features, flatten_out_features, num_groups): | |
super().__init__() | |
# Calculate flattened input dimension | |
self.flat_in_dim = in_features * num_groups | |
self.projection_layer = nn.Linear(in_features=self.flat_in_dim, out_features=flatten_out_features, bias=False) | |
def forward(self, x): | |
# x shape: (batch, seq, num_groups, in_features) | |
batch_size, t, _, _ = x.size() | |
x_flattened = x.reshape(batch_size, t, -1) # Flattening num_groups and in_features | |
return self.projection_layer(x_flattened) | |
class DepthwiseConvProjection(nn.Module): | |
def __init__(self, in_features, flatten_out_features, num_groups, depth): | |
super().__init__() | |
d_out = flatten_out_features // in_features | |
self.conv = nn.Conv2d(in_channels=num_groups, | |
out_channels=num_groups * d_out, | |
kernel_size=(1, depth), | |
groups=num_groups) | |
self.fc = nn.Linear(num_groups * d_out * (in_features - depth + 1), flatten_out_features) | |
def forward(self, x): | |
# Swap the dimensions of k and t to match expected input for depthwise convolution | |
x = x.permute(0, 2, 1, 3) # shape: (b, k, t, d) | |
# Convolutional layer | |
x = self.conv(x) # shape: (b, k*d_out, t, d-depth+1) | |
# Reshape the tensor for the Linear layer | |
batch_size, _, t, _ = x.size() | |
x = x.reshape(batch_size, t, -1) | |
return self.fc(x) | |
def get_projection_layer(input_shape: Tuple[int], output_shape: Tuple[int], proj_type: str) -> nn.Module: | |
""" This function returns one of the projection layers defined below. """ | |
if len(input_shape) == 2: | |
_, d_source = input_shape | |
elif len(input_shape) == 3: | |
_, k_source, d_source = input_shape | |
if len(output_shape) == 2: | |
_, d_target = output_shape | |
elif len(output_shape) == 3: | |
_, k_target, d_target = output_shape | |
if 'linear' == proj_type: | |
return LinearProjection(in_features=d_source, flatten_out_features=d_target, num_groups=k_source) | |
elif 'mlp' in proj_type: | |
if 'gelu' in proj_type: | |
return FlattenMLP(in_features=d_source, | |
flatten_out_features=d_target, | |
num_groups=k_source, | |
activation=nn.GELU) | |
elif 'relu' in proj_type: | |
return FlattenMLP(in_features=d_source, | |
flatten_out_features=d_target, | |
num_groups=k_source, | |
activation=nn.ReLU) | |
else: | |
return FlattenMLP(in_features=d_source, flatten_out_features=d_target, num_groups=k_source, activation=None) | |
elif 'conv' in proj_type: | |
if 'conv4' == proj_type: | |
return DepthwiseConvProjection(in_features=d_source, | |
flatten_out_features=d_target, | |
num_groups=k_source, | |
depth=4) | |
elif 'conv16' == proj_type: | |
return DepthwiseConvProjection(in_features=d_source, | |
flatten_out_features=d_target, | |
num_groups=k_source, | |
depth=16) | |
elif 'conv32' == proj_type: | |
return DepthwiseConvProjection(in_features=d_source, | |
flatten_out_features=d_target, | |
num_groups=k_source, | |
depth=32) | |
elif 'conv64' == proj_type: | |
return DepthwiseConvProjection(in_features=d_source, | |
flatten_out_features=d_target, | |
num_groups=k_source, | |
depth=64) | |
else: # conv depth 1 | |
return DepthwiseConvProjection(in_features=d_source, | |
flatten_out_features=d_target, | |
num_groups=k_source, | |
depth=1) | |
elif 'group_linear' == proj_type: | |
assert d_source % k_source == 0, "d_source and k_source must be divisible for group_linear projection." | |
return GroupLinearFlatten(in_features=d_source, | |
flatten_out_features=d_target, | |
num_groups=k_source, | |
use_bmm=True) | |
else: | |
raise ValueError(f"Invalid projection type: {proj_type}") | |
def test_projection_layers(): | |
# encoder hidden states: (B, T, K, D) | |
b = 2 | |
t = 110 #10 | |
k = 24 #16 | |
d = 128 | |
enc_hs = torch.randn(b, t, k, d) | |
# target shape: (B, T, K, D//4) | |
target_flatten_d = 512 | |
# GroupLinear | |
gl = GroupLinearFlatten(in_features=d, flatten_out_features=target_flatten_d, num_groups=k, use_bmm=True) | |
enc_hs_hat = gl(enc_hs) | |
assert enc_hs_hat.shape == (b, t, target_flatten_d) | |
print('GroupLinear: ', f'{count_parameters(gl)//1000}k') # 65k | |
# FlattenMLP | |
fm = FlattenMLP(in_features=d, | |
flatten_out_features=target_flatten_d, | |
num_groups=k, | |
hidden_dim=None, | |
activation=nn.GELU) | |
enc_hs_hat = fm(enc_hs) | |
assert enc_hs_hat.shape == (b, t, target_flatten_d) | |
print('FlattenMLP: ', f'{count_parameters(fm)//1000}k') # 3.6M | |
# LinearProjection | |
lp = LinearProjection(in_features=d, flatten_out_features=target_flatten_d, num_groups=k) | |
enc_hs_hat = lp(enc_hs) | |
assert enc_hs_hat.shape == (b, t, target_flatten_d) | |
print('LinearProjection: ', f'{count_parameters(lp)//1000}k') # 1M | |
# DepthwiseConvProjection | |
dc = DepthwiseConvProjection(in_features=d, flatten_out_features=target_flatten_d, num_groups=k, depth=16) | |
enc_hs_hat = dc(enc_hs) | |
assert enc_hs_hat.shape == (b, t, target_flatten_d) | |
print('DepthwiseConvProjection: ', f'{count_parameters(dc)//1000}k') # 4M | |