import torch
import torch.nn as nn
from torch import einsum

from einops import rearrange
# from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock

# q, k and v are all computed from the audio features (self-attention over the audio tokens)

class MyCrossAttention(nn.Module):
    def __init__(self, device="cuda", audio_dim=1024, context_dim=768, dropout=0.0, h=8, dim_head=40):
        super().__init__()
        self.h = h
        inner_dim = dim_head * h
        self.scale = dim_head ** -0.5

        self.to_q_adapter = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_k_adapter = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v_adapter = nn.Linear(context_dim, inner_dim, bias=False)

    def forward(self, audio):
        q_adapter = self.to_q_adapter(audio)
        k_adapter = self.to_k_adapter(audio)
        v_adapter = self.to_v_adapter(audio)

        # split heads: [b, n, h*d] -> [(b*h), n, d]
        q_adapter, k_adapter, v_adapter = map(
            lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=self.h),
            (q_adapter, k_adapter, v_adapter),
        )

        # scaled dot-product attention over the audio tokens
        sim_adapter = einsum('b i d, b j d -> b i j', q_adapter, k_adapter) * self.scale
        attn_adapter = sim_adapter.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn_adapter, v_adapter)
        # merge heads back: [(b*h), n, d] -> [b, n, h*d]
        out = rearrange(out, '(b h) n d -> b n (h d)', h=self.h)

        return out


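# Adapter: projects a raw audio embedding into a short sequence of context tokens
# ([bs, audio_token_count, context_dim]) intended as cross-attention conditioning for a
# downstream generative model; the exact consumer is not defined in this file.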
class Adapter(nn.Module):
    def __init__(self, device="cuda", audio_dim=1024, context_dim=768, dropout=0.0, h=8, dim_head=40, audio_token_count=10, initial_channel_dim=1, transformer_layer_count=4):
        super().__init__()
        self.h = h
        inner_dim = dim_head * h

        audio_att_inner_dim = audio_token_count

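        # Shape flow (with the defaults audio_dim=1024, audio_token_count=10):
        # [bs, initial_channel_dim, 1024] -> conv stack -> [bs, 10, 1024]
        # -> ConvTranspose1d (stride 3) -> [bs, 10, 3072] -> Conv1d (stride 4) -> [bs, 10, 768]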
        self.audio_emb_projection = nn.Sequential(
            nn.Conv1d(initial_channel_dim, audio_att_inner_dim, kernel_size = 17, stride = 1, padding = 8),
            nn.GELU(),
            nn.Conv1d(audio_att_inner_dim, audio_att_inner_dim, kernel_size = 17, stride = 1, padding = 8),
            nn.GELU(),
            nn.LayerNorm([audio_att_inner_dim, audio_dim]),
            nn.Conv1d(audio_att_inner_dim, audio_att_inner_dim, kernel_size = 17, stride = 1, padding = 8),
            nn.GELU(), 
            nn.LayerNorm([audio_att_inner_dim, audio_dim]),
            nn.ConvTranspose1d(audio_att_inner_dim, audio_att_inner_dim, kernel_size = 17, stride=3, padding=7),
            nn.GELU(), 
            nn.LayerNorm([audio_att_inner_dim, 3*audio_dim]),
            nn.GELU(), 
            nn.Conv1d(audio_att_inner_dim, audio_att_inner_dim, kernel_size = 17, stride=4, padding=7),
            nn.Dropout(dropout)
        )

        # stack of MyCrossAttention layers
        self.cross_attention = nn.ModuleList([
            MyCrossAttention(device, audio_dim, context_dim, dropout, h, dim_head)
            for _ in range(transformer_layer_count)
        ])

        # stack of linear -> GELU -> linear -> dropout blocks applied after each attention layer
        self.between_attention = nn.ModuleList([
            nn.Sequential(
                nn.Linear(inner_dim, inner_dim),
                nn.GELU(),
                nn.Linear(inner_dim, context_dim),
                nn.Dropout(dropout),
            )
            for _ in range(transformer_layer_count)
        ])

        self.to_out_adapter = nn.Sequential(
            nn.Linear(context_dim, context_dim),
            nn.Dropout(dropout),
        )
  

    def forward(self, audio_context):
        # project the raw audio embedding to a token sequence:
        # [bs, initial_channel_dim, audio_dim] -> [bs, audio_token_count, context_dim] (with the default sizes)
        audio_proj = self.audio_emb_projection(audio_context)

        # pass the tokens through the stacked attention + feed-forward blocks,
        # with a residual connection around each block
        out = audio_proj
        for cross_attention, between_attention in zip(self.cross_attention, self.between_attention):
            out = between_attention(cross_attention(out)) + out

        # final projection into the conditioning space: [bs, audio_token_count, context_dim]
        out = self.to_out_adapter(out)

        return out
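

# Minimal usage sketch: assumes the raw audio embedding arrives as a [batch, 1, 1024]
# tensor; the upstream audio encoder that would produce it is an assumption and is not
# defined in this file.
if __name__ == "__main__":
    adapter = Adapter(device="cpu", audio_dim=1024, context_dim=768, audio_token_count=10,
                      initial_channel_dim=1, transformer_layer_count=4)
    dummy_audio = torch.randn(2, 1, 1024)  # [bs, initial_channel_dim, audio_dim]
    tokens = adapter(dummy_audio)
    print(tokens.shape)  # expected: torch.Size([2, 10, 768])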