import torch
import torch.nn as nn
import yaml
from typing import Any

from diffusers import UNet2DModel
from einops import repeat
from torch import Tensor

def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
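    """Return a boolean tensor of `shape` whose entries are True with probability `proba`.

    The endpoints 0 and 1 are handled without sampling; otherwise a Bernoulli draw is used.
    """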
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)

class DiffVC(nn.Module):
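    """Diffusion-based voice conversion model.

    A UNet2DModel denoises mel-spectrogram frames conditioned on content
    features (concatenated along the channel axis), a speaker embedding
    (passed as the UNet's class labels), and, optionally, a pitch embedding.
    """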
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.unet = UNet2DModel(**self.config['unet'])
        self.unet.set_use_memory_efficient_attention_xformers(True)

        # Project the speaker vector into the UNet's class-embedding space.
        self.speaker_embedding = nn.Sequential(
            nn.Linear(self.config['cls_embedding']['speaker_dim'], self.config['cls_embedding']['feature_dim']),
            nn.SiLU(),
            nn.Linear(self.config['cls_embedding']['feature_dim'], self.config['cls_embedding']['feature_dim']))
        # Learned "unconditional" speaker vector for classifier-free guidance.
        self.uncond = nn.Parameter(torch.randn(self.config['cls_embedding']['speaker_dim']) /
                                   self.config['cls_embedding']['speaker_dim'] ** 0.5)

        self.content_embedding = nn.Sequential(
            nn.Linear(self.config['cls_embedding']['content_dim'], self.config['cls_embedding']['content_hidden']),
            nn.SiLU(),
            nn.Linear(self.config['cls_embedding']['content_hidden'], self.config['cls_embedding']['content_hidden']))

        if self.config['cls_embedding']['use_pitch']:
            self.pitch_control = True
            self.pitch_embedding = nn.Sequential(
                nn.Linear(self.config['cls_embedding']['pitch_dim'], self.config['cls_embedding']['pitch_hidden']),
                nn.SiLU(),
                nn.Linear(self.config['cls_embedding']['pitch_hidden'],
                          self.config['cls_embedding']['pitch_hidden']))
            # Learned "unconditional" pitch embedding for classifier-free guidance.
            self.pitch_uncond = nn.Parameter(torch.randn(self.config['cls_embedding']['pitch_hidden']) /
                                             self.config['cls_embedding']['pitch_hidden'] ** 0.5)
        else:
            print('no pitch module')
            self.pitch_control = False
    def forward(self, target, t, content, speaker, pitch,
                train_cfg=False, speaker_cfg=0.0, pitch_cfg=0.0):
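        """Predict the UNet output for noisy mel `target` at diffusion step `t`.

        With `train_cfg=True`, the speaker and pitch conditioning are each
        independently replaced by their learned unconditional embeddings with
        probability `speaker_cfg` and `pitch_cfg`, so the model also learns an
        unconditional branch usable for classifier-free guidance at sampling time.
        """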
        B, C, M, L = target.shape

        # Embed the content features and broadcast them across the mel-bin
        # axis so they can be concatenated with the target along channels.
        content = self.content_embedding(content)
        content = repeat(content, "b t c -> b c m t", m=M)
        target = target.to(content.dtype)
        x = torch.cat([target, content], dim=1)

        if self.pitch_control:
            if pitch is not None:
                pitch = self.pitch_embedding(pitch)
            else:
                # No pitch provided: fall back to the learned unconditional embedding.
                pitch = repeat(self.pitch_uncond, "c -> b t c", b=B, t=L).to(target.dtype)

        if train_cfg:
            # Randomly drop speaker (and pitch) conditioning per batch element.
            uncond = repeat(self.uncond, "c -> b c", b=B).to(target.dtype)
            batch_mask = rand_bool(shape=(B, 1), proba=speaker_cfg, device=target.device)
            speaker = torch.where(batch_mask, uncond, speaker)

            if self.pitch_control:
                batch_mask = rand_bool(shape=(B, 1, 1), proba=pitch_cfg, device=target.device)
                pitch_uncond = repeat(self.pitch_uncond, "c -> b t c", b=B, t=L).to(target.dtype)
                pitch = torch.where(batch_mask, pitch_uncond, pitch)

        speaker = self.speaker_embedding(speaker)

        if self.pitch_control:
            pitch = repeat(pitch, "b t c -> b c m t", m=M)
            x = torch.cat([x, pitch], dim=1)

        output = self.unet(sample=x, timestep=t, class_labels=speaker)['sample']
        return output

if __name__ == "__main__":
    # Smoke test: build the model from a config file and run one forward
    # pass with randomly generated tensors shaped like a real batch.
    with open('diffvc_base_pitch.yaml', 'r') as fp:
        config = yaml.safe_load(fp)

    device = 'cuda'
    model = DiffVC(config['diffwrap']).to(device)

    x = torch.rand((2, 1, 100, 256)).to(device)         # noisy mel: (B, C, M, L)
    y = torch.rand((2, 256, 768)).to(device)            # content features: (B, L, content_dim)
    p = torch.rand(2, 256, 1).to(device)                # pitch: (B, L, pitch_dim)
    t = torch.randint(0, 1000, (2,)).long().to(device)  # diffusion timesteps
    spk = torch.rand(2, 256).to(device)                 # speaker vectors: (B, speaker_dim)

    output = model(x, t, y, spk, pitch=p, train_cfg=True,
                   speaker_cfg=0.25, pitch_cfg=0.25)
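    # Note: this only exercises the training-time conditioning dropout. A
    # sampler would typically run the UNet twice per step (once conditional,
    # once with the learned unconditional embeddings) and blend the two
    # outputs with a guidance weight; that logic is not part of this file.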