import torch
import torch.nn as nn
import yaml
from einops import repeat
from typing import Any
from torch import Tensor
from diffusers import UNet2DModel


def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
    """Sample a boolean tensor where each element is True with probability `proba`."""
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)


class DiffVC(nn.Module):
    """Diffusion-based voice conversion wrapper around a diffusers UNet2DModel.

    The noisy mel target is concatenated channel-wise with content (and
    optionally pitch) embeddings; the speaker embedding is injected through
    the UNet's class-label conditioning path.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.unet = UNet2DModel(**self.config['unet'])
        # Requires the xformers package to be installed.
        self.unet.set_use_memory_efficient_attention_xformers(True)

        # Speaker conditioning, projected to the UNet's class-embedding size.
        self.speaker_embedding = nn.Sequential(
            nn.Linear(self.config['cls_embedding']['speaker_dim'],
                      self.config['cls_embedding']['feature_dim']),
            nn.SiLU(),
            nn.Linear(self.config['cls_embedding']['feature_dim'],
                      self.config['cls_embedding']['feature_dim']))
        # Learned "null" speaker embedding for classifier-free guidance.
        self.uncond = nn.Parameter(
            torch.randn(self.config['cls_embedding']['speaker_dim']) /
            self.config['cls_embedding']['speaker_dim'] ** 0.5)

        # Per-frame content conditioning (e.g. phonetic / SSL features).
        self.content_embedding = nn.Sequential(
            nn.Linear(self.config['cls_embedding']['content_dim'],
                      self.config['cls_embedding']['content_hidden']),
            nn.SiLU(),
            nn.Linear(self.config['cls_embedding']['content_hidden'],
                      self.config['cls_embedding']['content_hidden']))

        if self.config['cls_embedding']['use_pitch']:
            self.pitch_control = True
            self.pitch_embedding = nn.Sequential(
                nn.Linear(self.config['cls_embedding']['pitch_dim'],
                          self.config['cls_embedding']['pitch_hidden']),
                nn.SiLU(),
                nn.Linear(self.config['cls_embedding']['pitch_hidden'],
                          self.config['cls_embedding']['pitch_hidden']))
            # Learned "null" pitch embedding for classifier-free guidance.
            self.pitch_uncond = nn.Parameter(
                torch.randn(self.config['cls_embedding']['pitch_hidden']) /
                self.config['cls_embedding']['pitch_hidden'] ** 0.5)
        else:
            print('no pitch module')
            self.pitch_control = False

    def forward(self, target, t, content, speaker, pitch,
                train_cfg=False, speaker_cfg=0.0, pitch_cfg=0.0):
        # target:  (B, C, M, L) noisy mel spectrogram (C=1, M mel bins, L frames)
        # content: (B, L, content_dim), speaker: (B, speaker_dim)
        # pitch:   (B, L, pitch_dim) or None
        B, C, M, L = target.shape

        # Broadcast the per-frame content embedding across the mel-bin axis
        # and stack it onto the target along the channel axis.
        content = self.content_embedding(content)
        content = repeat(content, "b t c -> b c m t", m=M)
        target = target.to(content.dtype)
        x = torch.cat([target, content], dim=1)

        if self.pitch_control:
            if pitch is not None:
                pitch = self.pitch_embedding(pitch)
            else:
                # No pitch given: fall back to the learned null embedding.
                pitch = repeat(self.pitch_uncond, "c -> b t c", b=B, t=L).to(target.dtype)

        if train_cfg:
            # Randomly drop speaker (and pitch) conditioning per sample so the
            # model also learns an unconditional score (classifier-free guidance).
            uncond = repeat(self.uncond, "c -> b c", b=B).to(target.dtype)
            batch_mask = rand_bool(shape=(B, 1), proba=speaker_cfg, device=target.device)
            speaker = torch.where(batch_mask, uncond, speaker)

            if self.pitch_control:
                batch_mask = rand_bool(shape=(B, 1, 1), proba=pitch_cfg, device=target.device)
                pitch_uncond = repeat(self.pitch_uncond, "c -> b t c", b=B, t=L).to(target.dtype)
                pitch = torch.where(batch_mask, pitch_uncond, pitch)

        speaker = self.speaker_embedding(speaker)

        if self.pitch_control:
            pitch = repeat(pitch, "b t c -> b c m t", m=M)
            x = torch.cat([x, pitch], dim=1)

        # The speaker embedding enters through the UNet's class-label pathway.
        output = self.unet(sample=x, timestep=t, class_labels=speaker)['sample']
        return output


if __name__ == "__main__":
    with open('diffvc_base_pitch.yaml', 'r') as fp:
        config = yaml.safe_load(fp)

    device = 'cuda'
    model = DiffVC(config['diffwrap']).to(device)

    # Smoke test: B=2, 100 mel bins, 256 frames, 768-dim content, scalar pitch.
    x = torch.rand((2, 1, 100, 256)).to(device)
    y = torch.rand((2, 256, 768)).to(device)
    p = torch.rand(2, 256, 1).to(device)
    t = torch.randint(0, 1000, (2,)).long().to(device)
    spk = torch.rand(2, 256).to(device)

    # `forward` takes separate drop probabilities for speaker and pitch
    # (there is no single `cfg_prob` argument).
    output = model(x, t, y, spk, pitch=p, train_cfg=True,
                   speaker_cfg=0.25, pitch_cfg=0.25)
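
# ---------------------------------------------------------------------------
# For reference, a minimal sketch of what `diffvc_base_pitch.yaml` might look
# like. Only the `cls_embedding` input dimensions are pinned down by the smoke
# test above (speaker_dim=256, content_dim=768, pitch_dim=1); every other
# value, including the whole `unet` block, is an illustrative assumption.
# Note that `unet.in_channels` must equal 1 + content_hidden + pitch_hidden,
# since the mel target, content embedding, and pitch embedding are
# concatenated on the channel axis in `forward`.
#
# diffwrap:
#   unet:
#     sample_size: [100, 256]                    # (mel bins, frames) -- assumed
#     in_channels: 513                           # 1 + 256 + 256 -- assumed hiddens
#     out_channels: 1
#     block_out_channels: [128, 256, 256, 512]   # assumed
#     class_embed_type: identity                 # pass speaker embedding through as-is
#   cls_embedding:
#     speaker_dim: 256        # matches smoke test
#     feature_dim: 512        # must match the UNet's time-embedding dim
#     content_dim: 768        # matches smoke test
#     content_hidden: 256     # assumed
#     use_pitch: true
#     pitch_dim: 1            # matches smoke test
#     pitch_hidden: 256       # assumed
# ---------------------------------------------------------------------------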