Update model_simple.py
model_simple.py  +305 -47
model_simple.py
CHANGED
@@ -55,73 +55,225 @@ class rotary(nn.Module):
     x1 = x1.view(orig_shape)
     return torch.cat([x1.type_as(x), x2], dim=-1)
 
-
-
+    def shape(self, tensor: torch.Tensor, ctx: int, batch: int):
+        return tensor.view(batch, ctx, self.head, self.head_dim).transpose(1, 2).contiguous()
+
+    def reshape_to_output(self, attn_output, batch, ctx):
+        return attn_output.permute(0, 2, 1, 3).reshape(batch, ctx, self.dims).contiguous()
+
+def qkv_init(dims: int, head: int):
+    head_dim = dims // head
+    q = nn.Linear(dims, dims)
+    k = nn.Linear(dims, dims, bias=False)
+    v = nn.Linear(dims, dims)
+    o = nn.Linear(dims, dims)
+    lna = nn.LayerNorm(dims, bias=False)
+    lnb = nn.LayerNorm(head_dim, bias=False)
+    return q, k, v, o, lna, lnb
+
+def create_qkv(dims, head, q, k, v, x, xa=None):
+    z = default(xa, x)
+    head_dim = dims // head
+    scale = head_dim ** -0.25
+    q = q(x) * scale
+    k = k(z) * scale
+    v = v(z)
+    batch, ctx, dims = q.shape
+    def _shape(tensor):
+        return tensor.view(batch, ctx, head, head_dim).transpose(1, 2).contiguous()
+    return _shape(q), _shape(k), _shape(v)
+
+def calculate_attention(q, k, v, mask=None, temperature=1.0):
+    batch, head, ctx, dims = q.shape
+    scaled_q = q
+    if temperature != 1.0 and temperature > 0:
+        scaled_q = q * (1.0 / temperature)**.5
+    a = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
+    out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
+    return out, None
+
+class LocalAttentionModule(nn.Module):
+    def __init__(self, head_dim: int):
+        super().__init__()
+        self.head_dim = head_dim
+        self.query_module = nn.Linear(head_dim, head_dim)
+        self.key_module = nn.Linear(head_dim, head_dim)
+        self.value_module = nn.Linear(head_dim, head_dim)
+        self.out_proj = nn.Linear(head_dim, head_dim)
+
+    def _reshape_to_output(self, x):
+        return x
+
+class attentiona(nn.Module):
+    def __init__(self, dims: int, head: int, max_iters: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1):
         super(attention, self).__init__()
+
+        self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-        self.
-        self.
-        self.
-
-        self.
-        self.
-        self.
-
+        self.dropout = dropout
+        self.max_iters = max_iters
+        self.rope = rotary(dims=dims, head=head)
+
+        self.threshold = nn.Parameter(torch.tensor(threshold))
+        self.factor = nn.Parameter(torch.tensor(factor))
+        self.attn_local = LocalAttentionModule(self.head_dim)
+
+    def _focus(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None):
+        z = default(xa, x)
+        q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z))
+        # q=self.lnb(q)
+        # k=self.lnb(k)
+        iteration = 0
+        prev_attn = torch.zeros_like(q)
+        attn_out = torch.zeros_like(q)
+        threshold = self.threshold.item()
+        factor = self.factor.item()
+
+        q_cur = q
+        while iteration < self.max_iters:
+            eff_span = z.shape[1]
+            if eff_span == 0:
+                break
+
+            q_iter = q_cur[:, :, :eff_span, :]
+            k_iter = k[:, :, :eff_span, :]
+            v_iter = v[:, :, :eff_span, :]
+            q = self.attn_local.query_module(q_iter)
+            k = self.attn_local.key_module(k_iter)
+            v = self.attn_local.value_module(v_iter)
+
+            iter_mask = None
+            if mask is not None:
+                if mask.dim() == 4:
+                    iter_mask = mask[:, :, :eff_span, :eff_span]
+                elif mask.dim() == 2:
+                    iter_mask = mask[:eff_span, :eff_span]
+
+            q = self.rope(q, q.shape[2])
+            k = self.rope(k, k.shape[2])
+
+            attn_iter, _ = calculate_attention(
+                self.lnb(q), self.lnb(k), v, mask=iter_mask)
+
+            out_span = self.attn_local._reshape_to_output(attn_iter)
+            if out_span.dim() == 4:
+                b, h, s, d = out_span.shape
+                proj_span = self.attn_local.out_proj(out_span.view(-1, d)).view(b, h, s, -1)
+            elif out_span.dim() == 3:
+                b, s, d = out_span.shape
+                if d == self.head_dim:
+                    proj_span = self.attn_local.out_proj(out_span.view(-1, d)).view(b, 1, s, -1)
+                elif d == self.head * self.head_dim:
+                    proj_span = out_span.view(b, self.head, s, self.head_dim)
+                else:
+                    raise RuntimeError(f"Cannot reshape out_span of shape {out_span.shape} to [b, h, s, head_dim]")
+            else:
+                raise RuntimeError(f"Unexpected out_span shape: {out_span.shape}")
+
+            iter_out = torch.zeros_like(q_cur)
+            iter_out[:, :, :eff_span, :] = proj_span
+            diff = torch.abs(iter_out - prev_attn).mean()
+            dthresh = threshold + factor * diff
+            if diff < dthresh and iteration > 0:
+                attn_out = iter_out
+                break
+
+            prev_attn = iter_out.clone()
+            q_cur = q_cur + iter_out
+            attn_out = iter_out
+            iteration += 1
+
+        output = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
+        return self.o(output), None
+
+    def _slide_win_local(self, x: Tensor, win_size: int, span_len: int,
+                         mask: Optional[Tensor] = None) -> Tensor:
+        batch, ctx, dims = x.shape
+        output = torch.zeros_like(x)
+        num_win = (ctx + win_size - 1) // win_size
+
+        for i in range(num_win):
+            q_start = i * win_size
+            q_end = min(q_start + win_size, ctx)
+            q_len = q_end - q_start
+            if q_len == 0:
+                continue
+
+            kv_start = max(0, q_end - span_len)
+            kv_end = q_end
+            query_win = x[:, q_start:q_end, :]
+            key_win = x[:, kv_start:kv_end, :]
+
+            win_mask = None
+            if mask is not None:
+                if mask.dim() == 4:
+                    win_mask = mask[:, :, q_start:q_end, kv_start:kv_end]
+                elif mask.dim() == 2:
+                    win_mask = mask[q_start:q_end, kv_start:kv_end]
+
+            attn_out_win, _ = self._focus(
+                x=query_win,
+                xa=key_win,
+                mask=win_mask)
+            output[:, q_start:q_end, :] = attn_out_win
+        return output
+
+    def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None,
+                use_sliding_window: bool = False, win_size: int = 512, span_len: int = 1024) -> Tensor:
+        if use_sliding_window:
+            return self._slide_win_local(x, win_size, span_len, mask)
+        else:
+            output, _ = self._focus(x, xa, mask)
+            return output
+
+class attentionb(nn.Module):
+    def __init__(self, dims: int, head: int):
+        super(attentionb, self).__init__()
+        self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
+        self.dims = dims
+        self.head = head
+        self.head_dim = dims // head
+        self.rope = rotary(dims=dims, head=head)
+
     def forward(self, x: Tensor, xa = None, mask = None):
-
-        k = self.
-
-        q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-        k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-        v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
+        z = default(xa, x)
+        q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z))
+
         q = self.rope(q, q.shape[2])
         k = self.rope(k, k.shape[2])
-
+
+        a = scaled_dot_product_attention(self.lnb(q), self.lnb(k), v, is_causal=mask is not None and q.shape[1] > 1)
        out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
         return self.o(out)
 
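A minimal standalone sketch of the index math that _slide_win_local above walks: each window of win_size queries attends to a trailing span of at most span_len keys ending at that window. The helper window_spans is illustrative only and not part of model_simple.py; the 500/1500 values are the ones Residual.forward passes further down in this diff.

# Sliding-window index sketch (illustrative, not part of the diff).
# Each query window [q_start, q_end) reads keys in [kv_start, q_end).
def window_spans(ctx: int, win_size: int, span_len: int):
    spans = []
    num_win = (ctx + win_size - 1) // win_size  # ceil division, as in _slide_win_local
    for i in range(num_win):
        q_start = i * win_size
        q_end = min(q_start + win_size, ctx)
        kv_start = max(0, q_end - span_len)
        spans.append(((q_start, q_end), (kv_start, q_end)))
    return spans

# Example: 1300 key/value positions, windows of 500 queries, spans of up to 1500 keys.
print(window_spans(1300, 500, 1500))
# [((0, 500), (0, 500)), ((500, 1000), (0, 1000)), ((1000, 1300), (0, 1300))]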
-class
-    def __init__(self, dims, num_types=4):
-        super().__init__()
-        self.gates = nn.ModuleList([nn.Sequential(Linear(dims, 1), nn.Sigmoid()) for _ in range(num_types)])
-        self.classifier = nn.Sequential(Linear(dims, num_types), nn.Softmax(dim=-1))
-    def forward(self, x):
-        types = self.classifier(x)
-        gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
-        cgate = torch.sum(gates * types.unsqueeze(2), dim=-1)
-        return cgate
-
-class Residual(nn.Module):
-    _seen = set()
+class Residual(nn.Module):
     def __init__(self, dims: int, head: int, act: str = "silu"):
         super().__init__()
-
-        self.
-        self.
+
+        self.lna = nn.LayerNorm(dims, bias=False)
+        self.attnb = attentionb(dims, head)
+        self.attna = attentiona(dims, head, max_iters=3)
         self.mlp = nn.Sequential(Linear(dims, dims*4), get_activation(act), Linear(dims*4, dims))
-        self.tgate = tgate(dims=dims, num_types=4*2)
 
-    def forward(self, x, xa=None, mask=None) -> Tensor:
-
+    def forward(self, x, xa = None, mask = None) -> Tensor:
+
+        x = x + self.attnb(self.lna(x), xa=None, mask=mask)
         if xa is not None:
-            x = x + self.
-
-            x = b * xb + (1 - b) * x
-        out = self.mlp(self.ln(x))
-        gate = self.tgate(self.ln(x))
-        x = x + gate * out
+            x = x + self.attna(self.lna(x), xa, mask=None, use_sliding_window=True, win_size=500, span_len=1500)
+        x = x + self.mlp(self.lna(x))
         return x
-
+
 class processor(nn.Module):
     def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
         super(processor, self).__init__()
+
         self.ln = nn.LayerNorm(dims, device=device, dtype=dtype)
         self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
         self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
         self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
-        self.
+        self.posin = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
 
         act_fn = get_activation(act)
         self.encoder = nn.Sequential(
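On the pre-scaling in create_qkv earlier in this hunk: multiplying q and k each by head_dim ** -0.25 is algebraically the same as dividing the q·k logits by sqrt(head_dim), the conventional softmax scaling; note that scaled_dot_product_attention applies its own 1/sqrt(head_dim) factor on top unless its scale argument is overridden. A quick numerical check (illustrative, not part of the file):

# Check that pre-scaling q and k by head_dim ** -0.25 equals the standard
# 1/sqrt(head_dim) scaling of the attention logits.
import torch

head_dim = 64
q = torch.randn(2, 4, 10, head_dim)
k = torch.randn(2, 4, 10, head_dim)

s = head_dim ** -0.25
logits_prescaled = (q * s) @ (k * s).transpose(-1, -2)
logits_standard = (q @ k.transpose(-1, -2)) / head_dim ** 0.5

print(torch.allclose(logits_prescaled, logits_standard, atol=1e-5))  # True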
@@ -131,24 +283,41 @@ class processor(nn.Module):
 
         self.bA = nn.ModuleList([Residual(dims=dims, head=head, act=act_fn) for _ in range(layer)])
         self.bB = nn.ModuleList([Residual(dims=dims, head=head, act=act_fn) for _ in range(layer)])
+
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
 
-    def forward(self, x, xa) -> Tensor:
+    def forward(self, x, xa, sequential=False) -> Tensor:
 
         x = self.token(x.long()) + self.positional[:x.shape[1]]
         xa = self.encoder(xa).permute(0, 2, 1)
-        xa = xa + self.
+        xa = xa + self.posin(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
+
         for b in chain(self.bA or []):
             xa = b(x=xa, xa=None, mask=None)
+
         for b in chain(self.bB or []):
             x = b(x=x, xa=None, mask=self.mask)
-
+            y = b(x, xa=xa, mask=None)
+            if sequential:
+                x = y
+            else:
+                a = torch.sigmoid(self.blend)
+                x = a * y + (1 - a) * x
+
         x = nn.functional.dropout(x, p=0.001, training=self.training)
         x = self.ln(x)
         x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
         return x
 
+    def init_weights(self):
+        print("Initializing model weights...")
+        self.apply(self._init_weights)
+        print("Initialization summary:")
+        for module_type, count in self.init_counts.items():
+            if count > 0:
+                print(f"{module_type}: {count}")
+
 class Model(nn.Module):
     def __init__(self, param: Dimensions):
         super().__init__()
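A small sketch (illustrative, not part of the diff) of the learned blend that processor.forward applies above when it mixes the cross-attended stream back into the decoder stream: a single scalar parameter squashed through a sigmoid, initialized at 0.5 in __init__, so the gate starts near 0.62 and is adjusted by gradient descent.

# Learned scalar gate blending two streams, as in processor.forward.
import torch

blend = torch.nn.Parameter(torch.tensor(0.5))   # initial value from processor.__init__
x = torch.randn(1, 8, 16)    # decoder stream
y = torch.randn(1, 8, 16)    # stream after attending to the audio features

a = torch.sigmoid(blend)     # sigmoid(0.5) ~= 0.62 at initialization
out = a * y + (1 - a) * x    # gradients reach `blend`, so the mixing ratio is learned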
@@ -211,3 +380,92 @@ class Model(nn.Module):
             if count > 0:
                 print(f"{module_type}: {count}")
 
+def main():
+    token = ""
+    log_dir = os.path.join('D:/newmodel/output/logs/', datetime.now().strftime('%m-%d_%H_%M_%S'))
+    os.makedirs(log_dir, exist_ok=True)
+    tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
+
+    extract_args = {
+        "waveform": False,
+        "spec": False,
+        "f0": False,
+        "f0t": False,
+        "pitch": True,
+        "harmonics": False,
+        "aperiodics": False,
+        "phase_mod": False,
+        "crepe": False,
+        "sample_rate": 16000,
+        "hop_length": 256,
+        "mode": "mean",
+        "debug": False,
+    }
+
+    param = Dimensions(
+        vocab=40000,
+        mels=128,
+        ctx=2048,
+        dims=512,
+        head=4,
+        layer=4,
+        act="swish",
+    )
+
+    train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
+        load_saved=False, save_dataset=False, cache_dir=None, extract_args=extract_args, max_ctx=param.ctx)
+
+    model = Model(param).to('cuda')
+    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
+    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+    from functools import partial
+    metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1, tokenizer=tokenizer, model=model)
+
+    training_args = Seq2SeqTrainingArguments(
+        output_dir=log_dir,
+        per_device_train_batch_size=1,
+        per_device_eval_batch_size=1,
+        max_steps=1000,
+        eval_steps=100,
+        save_steps=1000,
+        warmup_steps=100,
+        logging_steps=10,
+        logging_dir=log_dir,
+        logging_strategy="steps",
+        eval_strategy="steps",
+        save_strategy="no",
+        report_to=["tensorboard"],
+        push_to_hub=False,
+        save_total_limit=1,
+        label_names=["labels"],
+        save_safetensors=False,
+        eval_on_start=False,
+        batch_eval_metrics=False,
+        disable_tqdm=False,
+        include_tokens_per_second=True,
+        include_num_input_tokens_seen=True,
+        learning_rate=0.00025,
+        weight_decay=0.025,
+    )
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999),
+        amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
+
+    trainer = Seq2SeqTrainer(
+        args=training_args,
+        model=model,
+        train_dataset=train_dataset,
+        eval_dataset=test_dataset,
+        data_collator=DataCollator(tokenizer=tokenizer),
+        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        compute_metrics=metrics_fn,
+        optimizers=(optimizer, scheduler)
+    )
+
+    model.init_weights()
+    trainer.train()
+if __name__ == "__main__":
+
+    main()
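For reference, a rough sketch (not part of the diff) of the learning-rate curve that CosineAnnealingLR traces with the values configured above (learning_rate=0.00025, T_max=1000 steps, eta_min=1e-9). Since the scheduler is handed to the Trainer through optimizers=, it takes the place of the Trainer's default schedule, so the warmup_steps=100 setting is likely unused.

# Cosine-annealed learning rate over the 1000 training steps configured above.
import math

base_lr, eta_min, t_max = 0.00025, 1e-9, 1000

def cosine_lr(step: int) -> float:
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))

for step in (0, 250, 500, 750, 1000):
    print(step, f"{cosine_lr(step):.2e}")
# 0 -> 2.50e-04, 250 -> 2.13e-04, 500 -> 1.25e-04, 750 -> 3.66e-05, 1000 -> 1.00e-09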