import torch
import torch.nn as nn
from diffusers import UNet2DModel, UNet2DConditionModel
import yaml
from einops import repeat

from typing import Any
from torch import Tensor


def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
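    """Return a boolean tensor of `shape` whose entries are True with probability `proba`."""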
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)


class FixedEmbedding(nn.Module):
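    """Learned, input-independent embedding used as the unconditional prompt for classifier-free guidance."""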
    def __init__(self, features=128):
        super().__init__()
        self.embedding = nn.Embedding(1, features)

    def forward(self, y):
        B, L, C, device = y.shape[0], y.shape[-2], y.shape[-1], y.device
        embed = self.embedding(torch.zeros(B, device=device).long())
        fixed_embedding = repeat(embed, "b c -> b l c", l=L)
        return fixed_embedding


class DiffVC_Cross(nn.Module):
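    """Cross-attention UNet wrapper (diffusers UNet2DConditionModel) that takes a noisy target
    spectrogram, a diffusion timestep, content features stacked as input channels, a speaker
    prompt fed through cross-attention, and (optionally) frame-level pitch."""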
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.unet = UNet2DConditionModel(**self.config['unet'])
        self.unet.set_use_memory_efficient_attention_xformers(True)
        self.cfg_embedding = FixedEmbedding(self.config['unet']['cross_attention_dim'])

        self.context_embedding = nn.Sequential(
            nn.Linear(self.config['unet']['cross_attention_dim'], self.config['unet']['cross_attention_dim']),
            nn.SiLU(),
            nn.Linear(self.config['unet']['cross_attention_dim'], self.config['unet']['cross_attention_dim']))

        self.content_embedding = nn.Sequential(
            nn.Linear(self.config['cls_embedding']['content_dim'], self.config['cls_embedding']['content_hidden']),
            nn.SiLU(),
            nn.Linear(self.config['cls_embedding']['content_hidden'], self.config['cls_embedding']['content_hidden']))

        if self.config['cls_embedding']['use_pitch']:
            self.pitch_control = True
            self.pitch_embedding = nn.Sequential(
                nn.Linear(self.config['cls_embedding']['pitch_dim'], self.config['cls_embedding']['pitch_hidden']),
                nn.SiLU(),
                nn.Linear(self.config['cls_embedding']['pitch_hidden'],
                          self.config['cls_embedding']['pitch_hidden']))

            self.pitch_uncond = nn.Parameter(torch.randn(self.config['cls_embedding']['pitch_hidden']) /
                                             self.config['cls_embedding']['pitch_hidden'] ** 0.5)
        else:
            print('no pitch module')
            self.pitch_control = False

    def forward(self, target, t, content, prompt, prompt_mask=None, pitch=None,
                train_cfg=False, speaker_cfg=0.0, pitch_cfg=0.0):
        B, C, M, L = target.shape
        # Embed content features and broadcast them over the mel-bin axis,
        # then stack them with the noisy target as extra input channels.
        content = self.content_embedding(content)
        content = repeat(content, "b t c -> b c m t", m=M)
        target = target.to(content.dtype)
        x = torch.cat([target, content], dim=1)

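        # Embed pitch when it is provided; otherwise use the learned unconditional
        # pitch token broadcast over all frames.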
        if self.pitch_control:
            if pitch is not None:
                pitch = self.pitch_embedding(pitch)
            else:
                pitch = repeat(self.pitch_uncond, "c-> b t c", b=B, t=L).to(target.dtype)

        if train_cfg:
            # Classifier-free guidance training: with probability `speaker_cfg`,
            # replace the speaker prompt with the fixed unconditional embedding.
            batch_mask = rand_bool(shape=(B, 1, 1), proba=speaker_cfg, device=target.device)
            fixed_embedding = self.cfg_embedding(prompt).to(target.dtype)
            prompt = torch.where(batch_mask, fixed_embedding, prompt)

            if self.pitch_control:
                # Likewise drop the pitch condition with probability `pitch_cfg`.
                batch_mask = rand_bool(shape=(B, 1, 1), proba=pitch_cfg, device=target.device)
                pitch_uncond = repeat(self.pitch_uncond, "c -> b t c", b=B, t=L).to(target.dtype)
                pitch = torch.where(batch_mask, pitch_uncond, pitch)

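        # Project the (possibly masked) prompt before it is passed as cross-attention context.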
        prompt = self.context_embedding(prompt)

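        # Broadcast pitch over the mel-bin axis and append it as extra input channels.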
        if self.pitch_control:
            pitch = repeat(pitch, "b t c-> b c m t", m=M)
            x = torch.cat([x, pitch], dim=1)

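        # Predict with the conditional UNet; the prompt conditions it via encoder_hidden_states.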
        output = self.unet(sample=x, timestep=t,
                           encoder_hidden_states=prompt,
                           encoder_attention_mask=prompt_mask)['sample']

        return output


if __name__ == "__main__":
    with open('diffvc_cross_pitch.yaml', 'r') as fp:
        config = yaml.safe_load(fp)
    device = 'cuda'

    model = DiffVC_Cross(config['diffwrap']).to(device)

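    # Dummy inputs (shapes must match the loaded config): noisy target (B, 1, mel_bins, frames),
    # content features (B, frames, content_dim), timesteps, speaker prompt + mask, frame-level pitch.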
    x = torch.rand((2, 1, 100, 256)).to(device)
    y = torch.rand((2, 256, 768)).to(device)
    t = torch.randint(0, 1000, (2,)).long().to(device)
    prompt = torch.rand(2, 64, 768).to(device)
    prompt_mask = torch.ones(2, 64).to(device)
    p = torch.rand(2, 256, 1).to(device)

    output = model(x, t, y, prompt, prompt_mask, p, train_cfg=True, speaker_cfg=0.25, pitch_cfg=0.5)