import numpy as np
import torch
from torch import nn
from torch.nn import utils

from vq.module import WNConv1d, EncoderBlock, ResLSTM
from vq.alias_free_torch import *
from vq import activations
from vq.bs_roformer5 import TransformerBlock
from torchtune.modules import RotaryPositionalEmbeddings
import vq.blocks as blocks


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


class CodecEncoder(nn.Module):
    def __init__(self,
                 ngf=48,
                 use_rnn=True,
                 rnn_bidirectional=False,
                 rnn_num_layers=2,
                 up_ratios=(2, 2, 4, 4, 5),
                 dilations=(1, 3, 9),
                 out_channels=1024):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.ngf = ngf
        self.up_ratios = up_ratios

        # Create first convolution
        d_model = ngf
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for i, stride in enumerate(up_ratios):
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride, dilations=dilations)]

        # RNN
        if use_rnn:
            self.block += [
                ResLSTM(d_model,
                        num_layers=rnn_num_layers,
                        bidirectional=rnn_bidirectional)
            ]

        # Create last convolution
        self.block += [
            Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)),
            WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
        ]

        # Wrap the blocks into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model
        self.reset_parameters()

    def forward(self, x):
        out = self.block(x)
        return out

    def inference(self, x):
        return self.block(x)

    def remove_weight_norm(self):
        """Remove weight normalization from all layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to all Conv1d layers."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)


class Transpose(nn.Module):
    def __init__(self, dim1, dim2):
        super(Transpose, self).__init__()
        self.dim1 = dim1
        self.dim2 = dim2

    def forward(self, x):
        return x.transpose(self.dim1, self.dim2)
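
# Shape sketch for CodecEncoder (an illustrative demo, not part of the original
# module; it assumes the vq package is importable and that each EncoderBlock
# downsamples exactly by its stride). With the default up_ratios=(2, 2, 4, 4, 5),
# hop_length = 2*2*4*4*5 = 320, so a 16000-sample mono waveform should map to
# roughly 16000 // 320 = 50 frames of out_channels=1024 features.
def _demo_codec_encoder():
    enc = CodecEncoder()
    wav = torch.randn(1, 1, 16000)  # (batch, channels, samples)
    z = enc(wav)                    # expected shape: (1, 1024, ~50)
    print("CodecEncoder output:", z.shape)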
class CodecEncoder_Transformer(nn.Module):
    def __init__(self,
                 ngf=48,
                 up_ratios=(2, 2, 4, 4, 5),
                 dilations=(1, 3, 9),
                 hidden_dim=1024,
                 depth=12,
                 heads=12,
                 pos_meb_dim=64):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.ngf = ngf
        self.up_ratios = up_ratios

        d_model = ngf
        self.conv_blocks = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
        for i, stride in enumerate(up_ratios):
            d_model *= 2
            self.conv_blocks += [EncoderBlock(d_model, stride=stride, dilations=dilations)]
        self.conv_blocks = nn.Sequential(*self.conv_blocks)

        # time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
        # transformer_blocks = [
        #     TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
        #     for _ in range(depth)
        # ]
        # self.transformers = nn.Sequential(*transformer_blocks)
        # self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)

        self.conv_final_block = [
            Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)),
            WNConv1d(d_model, hidden_dim, kernel_size=3, padding=1),
        ]
        self.conv_final_block = nn.Sequential(*self.conv_final_block)

        self.reset_parameters()

    def forward(self, x):
        x = self.conv_blocks(x)
        # x = x.permute(0, 2, 1)
        # x = self.transformers(x)
        # x = self.final_layer_norm(x)
        # x = x.permute(0, 2, 1)
        x = self.conv_final_block(x)
        x = x.permute(0, 2, 1)
        return x

    def inference(self, x):
        # The original delegated to self.block, which this class never defines;
        # route through forward() instead.
        return self.forward(x)

    def remove_weight_norm(self):
        """Remove weight normalization from all layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to all Conv1d layers."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)


class Codec_oobleck_Transformer(nn.Module):
    def __init__(self,
                 ngf=32,
                 up_ratios=(2, 2, 4, 4, 5),
                 dilations=(1, 3, 9),
                 hidden_dim=1024,
                 depth=12,
                 heads=16,
                 pos_meb_dim=64):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.ngf = ngf
        self.up_ratios = up_ratios
        self.hidden_dim = hidden_dim

        self.conv_blocks = blocks.DilatedResidualEncoder(
            capacity=ngf,
            dilated_unit=self.dilated_unit,
            downsampling_unit=self.downsampling_unit,
            ratios=up_ratios,
            dilations=dilations,
            pre_network_conv=self.pre_conv,
            post_network_conv=self.post_conv,
        )

        time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
        transformer_blocks = [
            TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
            for _ in range(depth)
        ]
        self.transformers = nn.Sequential(*transformer_blocks)
        self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)

        self.reset_parameters()

    def forward(self, x):
        x = self.conv_blocks(x)
        x = x.permute(0, 2, 1)
        x = self.transformers(x)
        x = self.final_layer_norm(x)
        return x

    def inference(self, x):
        # The original delegated to self.block, which this class never defines;
        # route through forward() instead.
        return self.forward(x)

    def remove_weight_norm(self):
        """Remove weight normalization from all layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to all Conv1d layers."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)

    def dilated_unit(self, hidden_dim, dilation):
        return blocks.DilatedConvolutionalUnit(hidden_dim,
                                               dilation,
                                               kernel_size=3,
                                               activation=nn.ReLU,
                                               normalization=utils.weight_norm)

    def downsampling_unit(self, input_dim: int, output_dim: int, stride: int):
        return blocks.DownsamplingUnit(input_dim,
                                       output_dim,
                                       stride,
                                       nn.ReLU,
                                       normalization=utils.weight_norm)

    def pre_conv(self, out_channels):
        return nn.Conv1d(1, out_channels, 1)

    def post_conv(self, in_channels):
        return nn.Conv1d(in_channels, self.hidden_dim, 1)


class CodecEncoder_only_Transformer(nn.Module):
    def __init__(self, hidden_dim=1024, depth=12, heads=16, pos_meb_dim=64):
        super().__init__()
        # self.embed = nn.Linear(input_dim, hidden_dim)  # with e.g. input_dim=300

        time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
        transformer_blocks = [
            TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
            for _ in range(depth)
        ]
        self.transformers = nn.Sequential(*transformer_blocks)
        self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x = self.embed(x)
        x = self.transformers(x)
        x = self.final_layer_norm(x)
        return x


def get_model_size(model):
    # Count all parameters
    total_params = sum(p.numel() for p in model.parameters())
    # Assume 32-bit floats: 4 bytes per parameter
    model_size_bytes = total_params * 4
    # Convert to a more readable unit (MB)
    model_size_mb = model_size_bytes / (1024 ** 2)
    return total_params, model_size_mb
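
# Usage sketch for CodecEncoder_only_Transformer (an illustrative demo, not part
# of the original module). Unlike the convolutional encoders above, this class
# consumes pre-extracted frame features of shape (batch, frames, hidden_dim) and
# returns a tensor of the same shape; depth=2 here is just to keep the demo cheap,
# and heads=16 over hidden_dim=1024 gives a 64-dim head matching pos_meb_dim.
def _demo_transformer_only_encoder():
    enc = CodecEncoder_only_Transformer(hidden_dim=1024, depth=2, heads=16)
    feats = torch.randn(1, 50, 1024)  # (batch, frames, hidden_dim)
    out = enc(feats)
    print("CodecEncoder_only_Transformer output:", out.shape)  # (1, 50, 1024)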
if __name__ == '__main__':
    model = Codec_oobleck_Transformer()
    x = torch.randn(1, 1, 16000)  # example input tensor
    output = model(x)
    print("Output shape:", output.shape)
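    # Report parameter count and approximate fp32 size via the helper above
    # (the MB figure assumes 4 bytes per parameter).
    total_params, model_size_mb = get_model_size(model)
    print(f"Parameters: {total_params:,}, approx size: {model_size_mb:.2f} MB")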