|
import torch |
|
from transformers import PreTrainedModel |
|
from .dmae_config import DMAE1dConfig |
|
from audio_encoders_pytorch import ME1d, TanhBottleneck |
|
from audio_diffusion_pytorch import UNetV0, LTPlugin, DiffusionAE |
|
|
|
class DMAE1d(PreTrainedModel):
    """HuggingFace-compatible wrapper around a 1D diffusion autoencoder.

    Builds a ``DiffusionAE`` (an ``LTPlugin``-wrapped ``UNetV0`` paired with
    an ``ME1d`` encoder) and exposes it through the ``transformers``
    save/load machinery via ``DMAE1dConfig``.
    """

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        super().__init__(config)

        # UNet backbone with the learned-transform (LT) plugin applied.
        net_type = LTPlugin(
            UNetV0,
            num_filters=128,
            window_length=64,
            stride=64,
        )

        # STFT-based encoder producing a 32-channel latent, passed through
        # a tanh bottleneck.
        latent_encoder = ME1d(
            in_channels=2,
            channels=512,
            multipliers=[1, 1, 1],
            factors=[2, 2],
            num_blocks=[4, 8],
            stft_num_fft=1023,
            stft_hop_length=256,
            out_channels=32,
            bottleneck=TanhBottleneck(),
        )

        self.model = DiffusionAE(
            net_t=net_type,
            dim=1,
            in_channels=2,
            channels=[256, 512, 512, 512, 1024, 1024, 1024],
            factors=[1, 2, 2, 2, 2, 2, 2],
            items=[1, 2, 2, 2, 2, 2, 2],
            encoder=latent_encoder,
            # Encoder latent is injected into the UNet at depth 4.
            inject_depth=4,
        )

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped ``DiffusionAE``."""
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate to ``DiffusionAE.encode``; gradients flow through."""
        return self.model.encode(*args, **kwargs)

    @torch.no_grad()
    def decode(self, *args, **kwargs):
        """Delegate to ``DiffusionAE.decode``; gradient-free (inference only)."""
        return self.model.decode(*args, **kwargs)
|
|
|
|