Update model_simple.py
model_simple.py (CHANGED, +68 -68)
@@ -9,7 +9,7 @@ import numpy as np
 from datetime import datetime
 from dataclasses import dataclass
 from torch.nn.functional import scaled_dot_product_attention
-
+from echoutils import *
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
 warnings.filterwarnings("ignore")
@@ -31,7 +31,7 @@ class rotary(nn.Module):
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-        self.theta = nn.Parameter((torch.tensor(
+        self.theta = nn.Parameter((torch.tensor(36000, device=device, dtype=dtype)), requires_grad=True)
         self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)

     def _compute_freqs_base(self):
@@ -53,6 +53,16 @@ class rotary(nn.Module):
         x1 = x1.view(orig_shape)
         return torch.cat([x1.type_as(x), x2], dim=-1)

+def shape(dims, head, q, k, v):
+    head_dim = dims // head
+    scale = head_dim ** -0.25
+    q = q * scale
+    k = k * scale
+    v = v
+    def _shape(tensor):
+        return tensor.view(*tensor.shape[:2], head, -1).permute(0, 2, 1, 3).contiguous()
+    return _shape(q), _shape(k), _shape(v)
+
 def qkv_init(dims: int, head: int):
     head_dim = dims // head
     q = nn.Linear(dims, dims)
@@ -60,60 +70,51 @@ def qkv_init(dims: int, head: int):
     v = nn.Linear(dims, dims)
     o = nn.Linear(dims, dims)
     lna = nn.LayerNorm(dims, bias=False)
-    lnb = nn.LayerNorm(
-
-
-
-    head_dim = dims // head
-    scale = head_dim ** -0.25
-    q = q(x) * scale
-    k = k(xa) * scale
-    v = v(xa)
-    batch, ctx, dims = x.shape
-    def _shape(tensor):
-        return tensor.view(batch, ctx, head, head_dim).transpose(1, 2).contiguous()
-    return _shape(q), _shape(k), _shape(v)
+    lnb = nn.LayerNorm(dims, bias=False)
+    lnc = nn.LayerNorm(head_dim, bias=False)
+    lnd = nn.LayerNorm(head_dim, bias=False)
+    return q, k, v, o, lna, lnb, lnc, lnd

-def calculate_attention(q, k, v, mask=None,
+def calculate_attention(q, k, v, mask=None, temp=1.0):
     scaled_q = q
-    if
-    scaled_q = q * (1.0 /
-
+    if temp != 1.0 and temp > 0:
+        scaled_q = q * (1.0 / temp)**.5
     out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
     return out

 class LocalOut(nn.Module):
-    def __init__(self,
+    def __init__(self, dims: int, head: int):
         super().__init__()
-
+        head_dim = dims // head
         self.query_module = nn.Linear(head_dim, head_dim)
         self.key_module = nn.Linear(head_dim, head_dim)
         self.value_module = nn.Linear(head_dim, head_dim)
         self.out_proj = nn.Linear(head_dim, head_dim)
-
-    def _reshape_to_output(self, x):
-        return x

-
-
-
-
+    def _reshape_to_output(self, attn_output: Tensor) -> Tensor:
+        batch, _, ctx, _ = attn_output.shape
+        return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
+
+class attentionb(nn.Module):
+    def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0):
+        super(attentionb, self).__init__()
+        self.q, self.k, self.v, self.o, self.lna, self.lnb, self.lnc, self.lnd = qkv_init(dims, head)
         self.dims = dims
         self.head = head
-        self.head_dim = dims // head
         self.max_iter = max_iter
         self.threshold = nn.Parameter(torch.tensor(threshold))
+        self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
         self.factor = nn.Parameter(torch.tensor(factor))
-        self.
-        self.lnc = nn.LayerNorm(self.head_dim, bias=False)
-        self.lnd = nn.LayerNorm(self.head_dim, bias=False)
-        self.attn_local = LocalOut(self.head_dim)
+        self.alocal = LocalOut(dims, head)

     def _focus(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None):
-
-
+        q = self.q(self.lna(x))
+        k = self.k(self.lnb(x if xa is None else xa))
+        v = self.v(self.lnb(x if xa is None else xa))
+        q, k, v = shape(self.dims, self.head, q, k, v)

         iteration = 0
+        temp = self.temp.item()
         prev_out = torch.zeros_like(q)
         attn_out = torch.zeros_like(q)
         threshold = self.threshold.item()
@@ -121,7 +122,7 @@ class attentiona(nn.Module):
         qcur = q

         while iteration < self.max_iter:
-            eff_span = min(
+            eff_span = min(qcur.shape[1], k.shape[1])
             if xa is not None:
                 eff_span = min(eff_span, xa.shape[1])
             if eff_span == 0:
@@ -130,9 +131,9 @@ class attentiona(nn.Module):
             qiter = qcur[:, :, :eff_span, :]
             kiter = k[:, :, :eff_span, :]
             viter = v[:, :, :eff_span, :]
-            q = self.
-            k = self.
-            v = self.
+            q = self.alocal.query_module(qiter)
+            k = self.alocal.key_module(kiter)
+            v = self.alocal.value_module(viter)

             iter_mask = None
             if mask is not None:
@@ -143,7 +144,7 @@ class attentiona(nn.Module):

             attn_iter = calculate_attention(
                 self.lnc(q), self.lnd(k), v,
-                mask=iter_mask)
+                mask=iter_mask, temp=temp)

             iter_out = torch.zeros_like(qcur)
             iter_out[:, :, :eff_span, :] = attn_iter
@@ -157,21 +158,22 @@ class attentiona(nn.Module):
             qcur = qcur + iter_out
             attn_out = iter_out
             iteration += 1
+            temp += 0.005

         output = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
         return self.o(output), None

     def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, mask: Optional[Tensor] = None) -> Tensor:

-        batch, ctx, dims = x.
+        batch, ctx, dims = x.shape
         output = torch.zeros_like(x)
         num_win = (ctx + win_size - 1) // win_size

         for i in range(num_win):
             qstart = i * win_size
             qend = min(qstart + win_size, ctx)
-
-            if
+            win_qlen = qend - qstart
+            if win_qlen == 0:
                 continue

             kstart = max(0, qend - span_len)
@@ -186,10 +188,7 @@ class attentiona(nn.Module):
             elif mask.dim() == 2:
                 win_mask = mask[qstart:qend, kstart:kend]

-            attn_out, _ = self._focus(
-                x=qwin,
-                xa=kwin,
-                mask=win_mask)
+            attn_out, _ = self._focus(x=qwin, xa=kwin, mask=win_mask)
             output[:, qstart:qend, :] = attn_out
         return output

@@ -201,21 +200,21 @@ class attentiona(nn.Module):
         output, _ = self._focus(x, xa, mask)
         return output

-class
+class attentiona(nn.Module):
     def __init__(self, dims: int, head: int):
-        super(
-        self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
+        super(attentiona, self).__init__()
+        self.q, self.k, self.v, self.o, self.lna, self.lnb, self.lnc, self.lnd = qkv_init(dims, head)
         self.dims = dims
         self.head = head
-        self.
-        self.rope = rotary(dims=dims, head=head)
-
+        self.rope = rotary(dims=dims, head=head)
     def forward(self, x: Tensor, xa = None, mask = None):
-
-
+        q = self.q(self.lna(x))
+        k = self.k(self.lnb(x if xa is None else xa))
+        v = self.v(self.lnb(x if xa is None else xa))
+        q, k, v = shape(self.dims, self.head, q, k, v)
         q = self.rope(q, q.shape[2])
         k = self.rope(k, k.shape[2])
-        a = scaled_dot_product_attention(self.
+        a = scaled_dot_product_attention(self.lnc(q), self.lnd(k), v, is_causal=mask is not None and q.shape[1] > 1)
         out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
         return self.o(out)

@@ -224,14 +223,15 @@ class Residual(nn.Module):
         super().__init__()

         self.lna = nn.LayerNorm(dims, bias=False)
-        self.
-        self.
+        self.attna = attentiona(dims, head)
+        self.attnb = attentionb(dims, head, max_iter=3)
         self.mlp = nn.Sequential(Linear(dims, dims*4), get_activation(act), Linear(dims*4, dims))

     def forward(self, x, xa = None, mask = None) -> Tensor:
-        x = x + self.
+        x = x + self.attna(self.lna(x), mask=mask)
         if xa is not None:
-            x = x + self.attna(self.lna(x), xa, None
+            x = x + self.attna(self.lna(x), xa, mask=None)
+            x = x + self.attnb(self.lna(x), xa, mask=None, use_sliding_win=True, win_size=256, span_len=512)
         x = x + self.mlp(self.lna(x))
         return x

@@ -239,8 +239,9 @@ class processor(nn.Module):
     def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
         super(processor, self).__init__()

-        self.
-        self.
+        self.lna = nn.LayerNorm(dims)
+        self.lnb = nn.LayerNorm(dims)
+        self.lnc = nn.LayerNorm(dims)
         self.token_emb = nn.Embedding(vocab, dims)
         self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
         self.audio_emb = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
@@ -252,7 +253,7 @@ class processor(nn.Module):
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

         self.bA = nn.ModuleList([Residual(dims, head, act_fn) for _ in range(layer)])
-
+
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)

@@ -263,14 +264,13 @@ class processor(nn.Module):
         xa = xa + self.audio_emb(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)

         for b in chain(self.bA or []):
-            xa = b(xa
-            x
-
-            xc = b(torch.cat([x, xa], dim=1), xa=xa, mask=self.mask) if modal else None
+            xa = b(self.lna(xa))
+            x = b(self.lnb(x), xa=xa, mask=self.mask)
+            xc = b(torch.cat([x, xa], dim=1), xa=None, mask=self.mask) if modal else None
             x = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None) if modal else x

         x = nn.functional.dropout(x, p=0.001, training=self.training)
-        x = self.
+        x = self.lnc(x)
         x = x @ torch.transpose(self.token_emb.weight.to(dtype), 0, 1).float()
         return x
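A minimal smoke test for the helpers introduced in this commit. This is a sketch, not part of the commit: it assumes the updated model_simple.py (and its echoutils dependency) imports cleanly, and the dims/head/batch/ctx values are illustrative only.

# Exercise the new shape() and calculate_attention() helpers with dummy tensors.
import torch
from model_simple import shape, calculate_attention, qkv_init

dims, head, batch, ctx = 64, 4, 2, 16
x = torch.randn(batch, ctx, dims)

# shape() scales q/k by head_dim ** -0.25 and splits heads:
# (batch, ctx, dims) -> (batch, head, ctx, head_dim)
q, k, v = shape(dims, head, x, x, x)

# calculate_attention() wraps scaled_dot_product_attention; temp rescales q.
out = calculate_attention(q, k, v, mask=None, temp=1.0)
print(out.shape)  # torch.Size([2, 4, 16, 16])

# qkv_init() now also returns the four LayerNorms used by the attention blocks.
q_lin, k_lin, v_lin, o_lin, lna, lnb, lnc, lnd = qkv_init(dims, head)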