Update model_simple.py
model_simple.py  (+196 −116)
Removed lines, by hunk (old line numbers; several deleted lines are truncated in this view):

@@ -11,7 +11,7 @@ from dataclasses import dataclass
-

@@ -55,42 +55,34 @@ class rotary(nn.Module):
-    def
-        return tensor.view(batch, ctx, self.head, self.head_dim).transpose(1, 2).contiguous()
-
-    def reshape_to_output(self, attn_output, batch, ctx):
-        return attn_output.permute(0, 2, 1, 3).reshape(batch, ctx, self.dims).contiguous()
-
-def qkv_init(dims: int, head: int):
-
-    lnb = nn.LayerNorm(head_dim, bias=False)
-    return q, k, v, o, lna, lnb
-def create_qkv(dims, head, q, k, v, x, xa
-    z = default(xa, x)
-    k = k(
-    v = v(
-    batch, ctx, dims =
-def calculate_attention(q, k, v, mask=None, temperature=1.0):
-    batch, head, ctx, dims = q.shape
-
-    out =
-

@@ -105,43 +97,56 @@ class LocalAttentionModule(nn.Module):
-    def __init__(self, dims: int, head: int,
-
-        self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
-        self.
-        self.max_iters = max_iters
-        self.rope = rotary(dims=dims, head=head)
-
-
-
-
-
-        attn_out = torch.zeros_like(
-
-
-
-
-
-
-        q = self.attn_local.query_module(
-        k = self.attn_local.key_module(
-        v = self.attn_local.value_module(

@@ -150,78 +155,63 @@ class attentiona(nn.Module):
-
-
-
-
-
-
-
-
-            b, h, s, d = out_span.shape
-            proj_span = self.attn_local.out_proj(out_span.view(-1, d)).view(b, h, s, -1)
-        elif out_span.dim() == 3:
-            b, s, d = out_span.shape
-            if d == self.head_dim:
-                proj_span = self.attn_local.out_proj(out_span.view(-1, d)).view(b, 1, s, -1)
-            elif d == self.head * self.head_dim:
-                proj_span = out_span.view(b, self.head, s, self.head_dim)
-            else:
-                raise RuntimeError(f"Cannot reshape out_span of shape {out_span.shape} to [b, h, s, head_dim]")
-        else:
-            raise RuntimeError(f"Unexpected out_span shape: {out_span.shape}")
-
-        iter_out = torch.zeros_like(q_cur)
-        iter_out[:, :, :eff_span, :] = proj_span
-        diff = torch.abs(iter_out - prev_attn).mean()
-
-
-    def _slide_win_local(self, x: Tensor, win_size: int, span_len: int,
-
-        batch, ctx, dims = x.
-
-
-
-        if
-
-
-
-
-                win_mask = mask[:, :,
-                win_mask = mask[
-
-            x=
-            xa=
-        output[:,
-
-        if

@@ -230,7 +220,6 @@ class attentiona(nn.Module):
-

@@ -240,10 +229,8 @@
-
-

@@ -254,56 +241,61 @@ class Residual(nn.Module):
-        self.attna = attentiona(dims, head,
-    def forward(self, x, xa = None, mask = None) -> Tensor:
-
-        x = x + self.attnb(self.lna(x), xa=None, mask=mask)
-            x = x + self.attna(self.lna(x), xa, mask=None,
-
-        self.
-        self.
-        self.
-        self.
-        self.bA = nn.ModuleList([Residual(dims
-        self.bB = nn.ModuleList([Residual(dims=dims, head=head, act=act_fn) for _ in range(layer)])
-        self.ln = nn.LayerNorm(dims, device=device, dtype=dtype)
-        x
-        xa = self.
-        xa = xa + self.
-
-        x =
-
-        x = x @ torch.transpose(self.
-

@@ -338,7 +330,7 @@ class Model(nn.Module):
-            "Conv2d": 0, "processor": 0, "

@@ -359,10 +351,9 @@
-            elif isinstance(module,
-                self.init_counts["
-            elif isinstance(module,
-                self.init_counts["attentionb"] += 1

@@ -373,3 +364,92 @@ (no deletions)
File content after the change, by region (new line numbers; lines added in this commit are prefixed with "+"):

Lines 11-17:
 from transformers.trainer_seq2seq import Seq2SeqTrainer
 from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
 from torch.nn.functional import scaled_dot_product_attention
+from echoutils import *
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
 warnings.filterwarnings("ignore")

Lines 55-86:
         x1 = x1.view(orig_shape)
         return torch.cat([x1.type_as(x), x2], dim=-1)

+def qkvinit(dims: int, head: int):
     head_dim = dims // head
+    scale = head_dim ** -0.5
     q = nn.Linear(dims, dims)
     k = nn.Linear(dims, dims, bias=False)
     v = nn.Linear(dims, dims)
     o = nn.Linear(dims, dims)
+    return q, k, v, o, scale

+def create_qkv(dims, head, q, k, v, x, xa):
     head_dim = dims // head
     scale = head_dim ** -0.25
     q = q(x) * scale
+    k = k(xa) * scale
+    v = v(xa)
+    batch, ctx, dims = x.shape
     def _shape(tensor):
         return tensor.view(batch, ctx, head, head_dim).transpose(1, 2).contiguous()
     return _shape(q), _shape(k), _shape(v)

+def calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True):
     scaled_q = q
     if temperature != 1.0 and temperature > 0:
         scaled_q = q * (1.0 / temperature)**.5
+
+    out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
+    # out = scaled_dot_product_attention(scaled_q, k, v, attn_mask=attn_mask, is_causal=is_causal if attn_mask is None else False)
+    return out

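For reference, a standalone shape check (not part of this commit) for the helpers above; it assumes create_qkv and calculate_attention are in scope, with dims=512 and head=4 as in the configuration used further down:

import torch
import torch.nn as nn

dims, head, ctx = 512, 4, 16
q_lin = nn.Linear(dims, dims)
k_lin = nn.Linear(dims, dims, bias=False)
v_lin = nn.Linear(dims, dims)
x = torch.randn(2, ctx, dims)

# create_qkv scales q and k by head_dim**-0.25, then splits heads:
q, k, v = create_qkv(dims, head, q_lin, k_lin, v_lin, x, x)
print(q.shape)    # torch.Size([2, 4, 16, 128])  -> [batch, head, ctx, head_dim]

# calculate_attention rescales q by (1/temperature)**0.5 before SDPA:
out = calculate_attention(q, k, v, mask=None, temperature=1.0)
print(out.shape)  # torch.Size([2, 4, 16, 128])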
Lines 87-88:
 class LocalAttentionModule(nn.Module):
     def __init__(self, head_dim: int):

Lines 97-152:
         return x

 class attentiona(nn.Module):
+    def __init__(self, dims: int, head: int, max_iterations: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1):
         super(attentiona, self).__init__()
+        # self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
+        self.max_iterations = max_iterations
         self.threshold = nn.Parameter(torch.tensor(threshold))
         self.factor = nn.Parameter(torch.tensor(factor))
+        self.dropout = dropout
+
+        self.q = nn.Linear(dims, dims)
+        self.k = nn.Linear(dims, dims, bias=False)
+        self.v = nn.Linear(dims, dims)
+        self.o = nn.Linear(dims, dims)
+
+        self.lna = nn.LayerNorm(dims, bias=False)
+        self.lnb = nn.LayerNorm(dims, bias=False)
+        self.lnc = nn.LayerNorm(self.head_dim, bias=False)
+        self.lnd = nn.LayerNorm(self.head_dim, bias=False)
         self.attn_local = LocalAttentionModule(self.head_dim)

     def _focus(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None):
+        q = self.q(self.lna(x))
+        k = self.k(self.lnb(x if xa is None else xa))
+        v = self.v(self.lnb(x if xa is None else xa))
+        query = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
+        key = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
+        value = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
+
         iteration = 0
+        prev_out = torch.zeros_like(query)
+        attn_out = torch.zeros_like(query)
         threshold = self.threshold.item()
         factor = self.factor.item()
+        qcur = query

+        while iteration < self.max_iterations:
+            eff_span = min(x.shape[1], qcur.size(1), key.size(1))
+            if xa is not None:
+                eff_span = min(eff_span, xa.shape[1])
            if eff_span == 0:
                break

+            qiter = qcur[:, :, :eff_span, :]
+            kiter = key[:, :, :eff_span, :]
+            viter = value[:, :, :eff_span, :]
+            q = self.attn_local.query_module(qiter)
+            k = self.attn_local.key_module(kiter)
+            v = self.attn_local.value_module(viter)

             iter_mask = None
             if mask is not None:

Lines 155-217:
                 elif mask.dim() == 2:
                     iter_mask = mask[:eff_span, :eff_span]

+            attn_iter = calculate_attention(
+                self.lnc(q), self.lnd(k), v,
+                mask=iter_mask,
+                is_causal=True)
+
+            iter_out = torch.zeros_like(qcur)
+            iter_out[:, :, :eff_span, :] = attn_iter
+            diff = torch.abs(iter_out - prev_out).mean()
             dthresh = threshold + factor * diff
+
             if diff < dthresh and iteration > 0:
                 attn_out = iter_out
                 break

+            prev_out = iter_out.clone()
+            qcur = qcur + iter_out
             attn_out = iter_out
             iteration += 1

         output = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
         return self.o(output), None

+    def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, mask: Optional[Tensor] = None, is_causal: bool = False) -> Tensor:
+
+        batch, ctx, dims = x.size()
         output = torch.zeros_like(x)
         num_win = (ctx + win_size - 1) // win_size

         for i in range(num_win):
+            qstart = i * win_size
+            qend = min(qstart + win_size, ctx)
+            current_win_qlen = qend - qstart
+            if current_win_qlen == 0:
                 continue

+            kvstart = max(0, qend - span_len)
+            kvend = qend
+            qwin = x[:, qstart:qend, :]
+            kwin = x[:, kvstart:kvend, :]

             win_mask = None
             if mask is not None:
                 if mask.dim() == 4:
+                    win_mask = mask[:, :, qstart:qend, kvstart:kvend]
                 elif mask.dim() == 2:
+                    win_mask = mask[qstart:qend, kvstart:kvend]

+            attn_out, _ = self._focus(
+                x=qwin,
+                xa=kwin,
                 mask=win_mask)
+            output[:, qstart:qend, :] = attn_out
         return output

     def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None,
+                use_sliding_win: bool = False, win_size: int = 512, span_len: int = 1024) -> Tensor:
+        if use_sliding_win:
             return self._slide_win_local(x, win_size, span_len, mask)
         else:
             output, _ = self._focus(x, xa, mask)

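To make the windowing in _slide_win_local concrete, here is a standalone sketch (not part of the commit) of the index arithmetic only, for an assumed sequence length of 1300 and the win_size=500 / span_len=1500 values that the Residual block below passes in. Each query window attends to at most span_len keys/values ending at the window's own last position:

ctx, win_size, span_len = 1300, 500, 1500
num_win = (ctx + win_size - 1) // win_size
for i in range(num_win):
    qstart = i * win_size
    qend = min(qstart + win_size, ctx)
    kvstart = max(0, qend - span_len)
    print(f"window {i}: queries [{qstart}:{qend}), keys/values [{kvstart}:{qend})")
# window 0: queries [0:500), keys/values [0:500)
# window 1: queries [500:1000), keys/values [0:1000)
# window 2: queries [1000:1300), keys/values [0:1300)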
Lines 220-225:
 class attentionb(nn.Module):
     def __init__(self, dims: int, head: int):
         super(attentionb, self).__init__()
         self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
         self.dims = dims
         self.head = head

Lines 229-236:
     def forward(self, x: Tensor, xa = None, mask = None):
         z = default(xa, x)
         q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z))
         q = self.rope(q, q.shape[2])
         k = self.rope(k, k.shape[2])
         a = scaled_dot_product_attention(self.lnb(q), self.lnb(k), v, is_causal=mask is not None and q.shape[1] > 1)
         out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
         return self.o(out)

Lines 241-253:

         self.lna = nn.LayerNorm(dims, bias=False)
         self.attnb = attentionb(dims, head)
+        self.attna = attentiona(dims, head, max_iterations=3)
         self.mlp = nn.Sequential(Linear(dims, dims*4), get_activation(act), Linear(dims*4, dims))

+    def forward(self, x, xa = None, mask = None) -> Tensor:
+        x = x + self.attnb(self.lna(x), xa=None, mask=mask)
         if xa is not None:
+            x = x + self.attna(self.lna(x), xa, mask=None, use_sliding_win=True, win_size=500, span_len=1500)
         x = x + self.mlp(self.lna(x))
         return x
+

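The block above follows a pre-norm residual pattern: causal self-attention, then sliding-window attention over the audio features when xa is present, then the MLP. A standalone sketch of that wiring with identity stand-ins for the two attention modules (illustrative only; the real block uses attentionb and attentiona as defined above):

import torch
import torch.nn as nn

dims = 512
ln = nn.LayerNorm(dims, bias=False)
self_attn = nn.Identity()   # stand-in for attentionb(dims, head)
cross_attn = nn.Identity()  # stand-in for attentiona(..., use_sliding_win=True)
mlp = nn.Sequential(nn.Linear(dims, dims * 4), nn.GELU(), nn.Linear(dims * 4, dims))

x = torch.randn(2, 10, dims)
x = x + self_attn(ln(x))    # x = x + attnb(lna(x), mask=mask)
x = x + cross_attn(ln(x))   # only applied when xa is not None
x = x + mlp(ln(x))
print(x.shape)              # torch.Size([2, 10, 512])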
Lines 254-301:
 class processor(nn.Module):
     def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
         super(processor, self).__init__()

+        self.ln = nn.LayerNorm(dims)
+        self.blend = nn.Parameter(torch.tensor(0.5), requires_grad=True)
+        self.token_emb = nn.Embedding(vocab, dims)
+        self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
+        self.audio_emb = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)

         act_fn = get_activation(act)
+        self.audio_enc = nn.Sequential(
             Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

+        self.bA = nn.ModuleList([Residual(dims, head, act_fn) for _ in range(layer)])

         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)

     def forward(self, x, xa, sequential=False) -> Tensor:

+        x = self.token_emb(x.long()) + self.positions[:x.shape[1]]
+        xa = self.audio_enc(xa).permute(0, 2, 1)
+        xa = xa + self.audio_emb(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)

         for b in chain(self.bA or []):
             xa = b(x=xa, xa=None, mask=None)
+            x = b(x=x, xa=None, mask=self.mask)
+            x = b(x=x, xa=xa, mask=None)
+            # xc = b(torch.cat([x, xa], dim=1), xa=None, mask=self.mask)
+            # x = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None)

+        # if sequential:
+        #     x = y
+        # else:
+        #     a = torch.sigmoid(self.blend)
+        #     x = a * y + (1 - a) * x

         x = nn.functional.dropout(x, p=0.001, training=self.training)
         x = self.ln(x)
+        x = x @ torch.transpose(self.token_emb.weight.to(dtype), 0, 1).float()
         return x
+
     def init_weights(self):
         print("Initializing model weights...")
         self.apply(self._init_weights)

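The final projection in processor.forward ties the output logits to the token embedding rather than using a separate head. A standalone sketch of that step alone (not part of the commit; sizes assumed from the Dimensions used in main below):

import torch
import torch.nn as nn

vocab, dims = 40000, 512
token_emb = nn.Embedding(vocab, dims)
hidden = torch.randn(2, 7, dims)            # decoder output after the final LayerNorm
logits = hidden @ token_emb.weight.transpose(0, 1)
print(logits.shape)                         # torch.Size([2, 7, 40000])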
Lines 330-336:
     def _init_weights(self, module):
         self.init_counts = {
             "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
+            "Conv2d": 0, "processor": 0, "attention": 0, "Residual": 0}
         for name, module in self.named_modules():
             if isinstance(module, RMSNorm):
                 nn.init.ones_(module.weight)

Lines 351-359:
                 if module.bias is not None:
                     nn.init.zeros_(module.bias)
                 self.init_counts["Conv2d"] += 1
+            elif isinstance(module, Residual):
+                self.init_counts["Residual"] += 1
+            elif isinstance(module, processor):
                 self.init_counts["processor"] += 1

     def init_weights(self):

Lines 364-366:
             if count > 0:
                 print(f"{module_type}: {count}")

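_init_weights above tallies module types while initializing them. The same kind of census can be taken with named_modules and a Counter; a standalone sketch (not this file's exact counting logic):

import torch.nn as nn
from collections import Counter

net = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Conv1d(8, 8, kernel_size=3))
counts = Counter(type(m).__name__ for m in net.modules())
print(counts)  # Counter({'Sequential': 1, 'Linear': 1, 'LayerNorm': 1, 'Conv1d': 1})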
Lines 367-455:
+def main():
+    token = ""
+    log_dir = os.path.join('D:/newmodel/output/logs/', datetime.now().strftime('%m-%d_%H_%M_%S'))
+    os.makedirs(log_dir, exist_ok=True)
+    tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
+
+    extract_args = {
+        "waveform": False,
+        "spec": False,
+        "f0": False,
+        "f0t": False,
+        "pitch": True,
+        "harmonics": False,
+        "aperiodics": False,
+        "phase_mod": False,
+        "crepe": False,
+        "sample_rate": 16000,
+        "hop_length": 256,
+        "mode": "mean",
+        "debug": False,
+    }
+
+    param = Dimensions(
+        vocab=40000,
+        mels=128,
+        ctx=2048,
+        dims=512,
+        head=4,
+        layer=4,
+        act="swish",
+    )
+
+    train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
+        load_saved=False, save_dataset=False, cache_dir=None, extract_args=extract_args, max_ctx=param.ctx)
+
+    model = Model(param).to('cuda')
+    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
+    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+    from functools import partial
+    metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1, tokenizer=tokenizer, model=model)
+
+    training_args = Seq2SeqTrainingArguments(
+        output_dir=log_dir,
+        per_device_train_batch_size=1,
+        per_device_eval_batch_size=1,
+        max_steps=1000,
+        eval_steps=100,
+        save_steps=1000,
+        warmup_steps=100,
+        logging_steps=10,
+        logging_dir=log_dir,
+        logging_strategy="steps",
+        eval_strategy="steps",
+        save_strategy="no",
+        report_to=["tensorboard"],
+        push_to_hub=False,
+        save_total_limit=1,
+        label_names=["labels"],
+        save_safetensors=False,
+        eval_on_start=False,
+        batch_eval_metrics=False,
+        disable_tqdm=False,
+        include_tokens_per_second=True,
+        include_num_input_tokens_seen=True,
+        learning_rate=0.00025,
+        weight_decay=0.025,
+    )
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999),
+        amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
+
+    trainer = Seq2SeqTrainer(
+        args=training_args,
+        model=model,
+        train_dataset=train_dataset,
+        eval_dataset=test_dataset,
+        data_collator=DataCollator(tokenizer=tokenizer),
+        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        compute_metrics=metrics_fn,
+        optimizers=(optimizer, scheduler)
+    )
+
+    model.init_weights()
+    trainer.train()
+
+if __name__ == "__main__":
+
+    main()
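Because main passes optimizers=(optimizer, scheduler), the Trainer uses this AdamW/CosineAnnealingLR pair instead of creating its own. A standalone sketch (not part of the commit) of how that schedule decays the learning rate over the 1000 steps configured above:

import torch

p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([p], lr=0.00025)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=1000, eta_min=1e-9)

lrs = []
for _ in range(1000):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])

print(lrs[0])    # just below 2.5e-4
print(lrs[499])  # roughly half the base rate at the midpoint
print(lrs[-1])   # approaches eta_min (1e-9) at step 1000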