Sin2pi committed on
Commit 7cc3b2c · verified · 1 Parent(s): 67c6d40

Update model_simple.py

Files changed (1):
  model_simple.py  +65 -169
model_simple.py CHANGED
@@ -35,34 +35,24 @@ class rotary(nn.Module):
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-        self.theta = nn.Parameter((torch.tensor(36000, device=device, dtype=dtype)), requires_grad=True)
-        self.twotwenty = nn.Parameter((torch.tensor(220, device=device, dtype=dtype)), requires_grad=True)
-    def forward(self, x=None) -> Tensor:
-        freqs = (self.theta / 220.0) * 700 * (
-            torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
-            self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
-        t = torch.arange(x, device=device, dtype=dtype) # type: ignore
+        self.theta = nn.Parameter((torch.tensor(10000, device=device, dtype=dtype)), requires_grad=True)
+    def forward(self, x, ctx) -> Tensor:
+        freqs = (self.theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
+        t = torch.arange(ctx, device=device, dtype=dtype)
         freqs = t[:, None] * freqs
         freqs=torch.polar(torch.ones_like(freqs), freqs)
-        return freqs.unsqueeze(0)
-    @staticmethod
-    def apply_rotary(x, freqs):
         x1 = x[..., :freqs.shape[-1]*2]
         x2 = x[..., freqs.shape[-1]*2:]
         orig_shape = x1.shape
-        if x1.ndim == 2:
-            x1 = x1.unsqueeze(0)
         x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
         x1 = torch.view_as_complex(x1) * freqs
         x1 = torch.view_as_real(x1).flatten(-2)
         x1 = x1.view(orig_shape)
         return torch.cat([x1.type_as(x), x2], dim=-1)
 
-class MultiheadA(nn.Module):
+class attention(nn.Module):
     def __init__(self, dims: int, head: int):
-        super(MultiheadA, self).__init__()
+        super(attention, self).__init__()
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
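The new forward folds the old apply_rotary helper into a single pass: it builds a mel-spaced frequency ladder (700 * (10**(m/2595) - 1) is the standard mel-to-Hz inversion, scaled by the learned theta / 220), turns position x frequency products into unit complex numbers with torch.polar, and rotates the vector pairwise via view_as_complex. A minimal self-contained sketch of that rotation, with head_dim, ctx and the input tensor chosen purely for illustration (here head_dim equals freqs.shape[-1]*2, so the whole vector is rotated):

import torch

head_dim, ctx, theta = 64, 6, 10000.0                       # illustrative sizes only
# mel-spaced frequency ladder: 700*(10**(m/2595) - 1) maps mel values back to Hz
mel_max = (2595 * torch.log10(torch.tensor(1 + 8000 / 700))).item()
mel = torch.linspace(0, mel_max, head_dim // 2)
freqs = (theta / 220.0) * 700 * (torch.pow(10, mel / 2595) - 1) / 1000

t = torch.arange(ctx, dtype=torch.float32)                  # token positions
rot = torch.polar(torch.ones(ctx, head_dim // 2), t[:, None] * freqs)  # e^{i * pos * freq}

x = torch.randn(1, 1, ctx, head_dim)                        # (batch, head, ctx, head_dim)
pairs = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2).contiguous())
rotated = torch.view_as_real(pairs * rot).flatten(-2).view(x.shape).type_as(x)
print(rotated.shape)                                        # torch.Size([1, 1, 6, 64])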
@@ -71,7 +61,7 @@ class MultiheadA(nn.Module):
         self.v = nn.Linear(dims, dims).to(device, dtype)
         self.o = nn.Linear(dims, dims).to(device, dtype)
         self.rope = rotary(dims=dims, head=head)
-        self.lnq = nn.LayerNorm(self.head_dim, bias = False)
+        self.lny = nn.LayerNorm(self.head_dim, bias = False)
         self.lnx = nn.LayerNorm(dims, bias = False)
     def forward(self, x: Tensor, xa = None, mask = None):
         scale = (self.dims // self.head) ** -0.25
@@ -81,15 +71,13 @@ class MultiheadA(nn.Module):
         q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
         k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
         v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-        q = self.lnq(q)
-        k = self.lnq(k)
-        q = self.rope.apply_rotary(q, (self.rope(q.shape[2]))) # type: ignore
-        k = self.rope.apply_rotary(k, (self.rope(k.shape[2]))) # type: ignore
-        a = scaled_dot_product_attention(q, k, v, is_causal=mask is not None and q.shape[1] > 1)
+        q = self.rope(q, q.shape[2])
+        k = self.rope(k, k.shape[2])
+        a = scaled_dot_product_attention(self.lny(q), self.lny(k), v, is_causal=mask is not None and q.shape[1] > 1)
         out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
         return self.o(out)
 
-class t_gate(nn.Module):
+class tgate(nn.Module):
     def __init__(self, dims, num_types=4):
         super().__init__()
         self.gate_projections = nn.ModuleList([
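After this change, q and k are rotated first and only normalized per head (via lny) at the point they enter scaled_dot_product_attention, and causal masking is delegated to the is_causal flag rather than an explicit additive mask. A minimal sketch of the split-heads / attention / merge pattern, with dims, head and the tensors made up for illustration (the q/k/v projections are replaced by stand-ins and is_causal is hard-coded True):

import torch
from torch.nn.functional import scaled_dot_product_attention

batch, ctx, dims, head = 2, 8, 64, 4
head_dim = dims // head
x = torch.randn(batch, ctx, dims)

def split_heads(t):
    # (batch, ctx, dims) -> (batch, head, ctx, head_dim)
    return t.view(batch, ctx, head, head_dim).permute(0, 2, 1, 3)

q, k, v = split_heads(x), split_heads(x), split_heads(x)    # stand-ins for the q/k/v projections
ln = torch.nn.LayerNorm(head_dim, bias=False)               # per-head norm, as with lny above

a = scaled_dot_product_attention(ln(q), ln(k), v, is_causal=True)  # causal mask applied internally
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)            # back to (batch, ctx, dims)
print(out.shape)                                            # torch.Size([2, 8, 64])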
@@ -106,89 +94,58 @@ class t_gate(nn.Module):
 
 class Residual(nn.Module):
     _seen = set()
-    def __init__(self, dims: int, head: int, ctx: int, act: str = "silu"):
+    def __init__(self, dims: int, head: int, act: str = "silu"):
         super().__init__()
-        self.dims = dims
-        self.head = head
-        self.ctx = ctx
-        self.head_dim = dims // head
         act_fn = get_activation(act)
         self.blend = nn.Parameter(torch.tensor(0.5))
-        self.attn = MultiheadA(dims, head)
-        mlp = dims * 4
-        self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))
-        self.t_gate = t_gate(dims=dims, num_types=4*2)
+        self.attn = attention(dims, head)
+        self.mlp = nn.Sequential(Linear(dims, dims*4), act_fn, Linear(dims*4, dims))
+        self.tgate = tgate(dims=dims, num_types=4*2)
         self.lna = nn.LayerNorm(dims, bias = False)
-        self.lnb = nn.LayerNorm(dims, bias = False)
-        self.lnc = nn.LayerNorm(dims, bias = False)
     def forward(self, x, xa=None, mask=None) -> Tensor:
-        x = x + self.attn(self.lna(x), xa=None, mask=mask)[0]
-        xb = x
+        xb = x + self.attn(self.lna(x), xa=None, mask=mask)[0]
         if xa is not None:
-            x = x + self.attn(self.lnb(x), xa=xa, mask=None)[0] # type: ignore
+            x = x + self.attn(self.lna(x), xa=xa, mask=None)[0]
             b = torch.sigmoid(self.blend)
            x = b * xb + (1 - b) * x
-        normx = self.lnc(x)
-        mlp_out = self.mlp(normx)
-        gate = self.t_gate(normx)
-        x = x + gate * mlp_out
+        out = self.mlp(self.lna(x))
+        gate = self.tgate(self.lna(x))
+        x = x + gate * out
         return x
 
 class processor(nn.Module):
     def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
         super(processor, self).__init__()
-        self.dims = dims
-        self.head = head
-        self.layer = layer
-        self.ctx = ctx
-        self.act = act
-        self.dropout = 0.01
         act_fn = get_activation(act)
         self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
         self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
         self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
         self.positional_sin = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
-        # pitch
         self.encoder = nn.Sequential(
             Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
-        # self.encoder = nn.Sequential(
-        #     Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
-        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
-        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
-        self.bA = nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act=act_fn) for _ in range(layer)])
-        self.bB = nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act=act_fn) for _ in range(layer)])
+        self.bA = nn.ModuleList([Residual(dims=dims, head=head, act=act_fn) for _ in range(layer)])
+        self.bB = nn.ModuleList([Residual(dims=dims, head=head, act=act_fn) for _ in range(layer)])
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
         self.norm = nn.LayerNorm(dims, device=device, dtype=dtype)
 
-    def forward(self, x, xa, sequential=False) -> Tensor:
+    def forward(self, x, xa) -> Tensor:
         x = self.token(x.long()) + self.positional[:x.shape[1]]
         xa = self.encoder(xa).permute(0, 2, 1)
-        xa = xa + self.positional_sin(xa.shape[1], xa.shape[-1], 36000).to(device, dtype)
+        xa = xa + self.positional_sin(xa.shape[1], xa.shape[-1], 10000.0).to(device, dtype)
         for b in chain(self.bA or []):
             xa = b(x=xa, xa=None, mask=None)
         for b in chain(self.bB or []):
             x = b(x=x, xa=None, mask=self.mask)
             x = b(x, xa=xa, mask=None)
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        x = nn.functional.dropout(x, p=0.001, training=self.training)
         x = self.norm(x)
         x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
         return x
 
-class Echo(nn.Module):
+class Model(nn.Module):
     def __init__(self, param: Dimensions):
         super().__init__()
         self.param = param
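The rewritten Residual keeps a single shared LayerNorm (lna) and mixes the self-attention branch with the cross-attention branch through one learned scalar: b = sigmoid(blend), x = b * xb + (1 - b) * x, then adds an MLP output scaled element-wise by the tgate projection. A small numeric sketch of just that blend-and-gate arithmetic, with random tensors standing in for the attention, MLP and gate outputs (the (0, 1) range of the gate is an assumption, since tgate's body is not shown in this hunk):

import torch

dims = 8
blend = torch.nn.Parameter(torch.tensor(0.5))

xb = torch.randn(1, 4, dims)         # stand-in: output of the self-attention branch
x_cross = torch.randn(1, 4, dims)    # stand-in: output of the cross-attention branch

b = torch.sigmoid(blend)             # sigmoid(0.5) ~ 0.62 at init, learned thereafter
x = b * xb + (1 - b) * x_cross       # learned mix of the two branches

mlp_out = torch.randn(1, 4, dims)    # stand-in: self.mlp(self.lna(x))
gate = torch.sigmoid(torch.randn(1, 4, dims))  # stand-in: self.tgate(self.lna(x)), assumed in (0, 1)
x = x + gate * mlp_out               # gated residual update
print(x.shape)                       # torch.Size([1, 4, 8])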
@@ -204,43 +161,20 @@ class Echo(nn.Module):
         )
 
     def forward(self,
-        labels=None,
-        input_ids=None,
-        spectrogram: Optional[torch.Tensor]=None,
-        pitch: Optional[torch.Tensor]=None,
-        ) -> Dict[str, Optional[torch.Tensor]]:
-        enc = {}
+        labels=None, input_ids=None, pitch: Optional[torch.Tensor]=None) -> Dict[str, Optional[torch.Tensor]]:
         if pitch is not None:
             xa = pitch
-        if spectrogram is not None:
-            xa = spectrogram
         x = input_ids
         logits = self.processor(x, xa)
         loss = None
         if labels is not None:
-            loss = torch.nn.functional.cross_entropy(
-                logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
+            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1))
         return {"logits": logits, "loss": loss}
 
-    @property
-    def device(self):
-        return next(self.parameters()).device
-    @property
-    def dtype(self):
-        return next(self.parameters()).dtype
     def _init_weights(self, module):
-        std = 0.02
         self.init_counts = {
             "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
-            "Conv2d": 0, "processor": 0, "Echo": 0,
-            "Residual": 0, "MultiheadA": 0,
-            "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
-            "WEncoder": 0, "PEncoder": 0, "feature_encoder": 0}
+            "Conv2d": 0, "processor": 0, "attention": 0, "Residual": 0}
         for name, module in self.named_modules():
             if isinstance(module, RMSNorm):
                 nn.init.ones_(module.weight)
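With the spectrogram path dropped, the model's forward now takes only input_ids and pitch, and the loss flattens (batch, ctx, vocab) logits against (batch, ctx) labels. Note that the rewritten call no longer passes ignore_index=0, so label id 0 (previously treated as padding) now contributes to the loss. A shape-only sketch of that loss computation, with random stand-in logits and made-up sizes:

import torch
import torch.nn.functional as F

batch, ctx, vocab = 2, 16, 100
logits = torch.randn(batch, ctx, vocab)     # stand-in for self.processor(input_ids, pitch)
labels = torch.randint(0, vocab, (batch, ctx))

loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1))
# the previous version masked padding instead: F.cross_entropy(..., ignore_index=0)
print(loss.item())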
@@ -252,23 +186,21 @@ class Echo(nn.Module):
                     nn.init.zeros_(module.bias)
                 self.init_counts["Linear"] += 1
             elif isinstance(module, Conv1d):
-                nn.init.normal_(module.weight, mean=0.0, std=std)
+                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                 if module.bias is not None:
                     nn.init.zeros_(module.bias)
                 self.init_counts["Conv1d"] += 1
             elif isinstance(module, Conv2d):
-                nn.init.normal_(module.weight, mean=0.0, std=std)
+                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                 if module.bias is not None:
                     nn.init.zeros_(module.bias)
                 self.init_counts["Conv2d"] += 1
-            elif isinstance(module, MultiheadA):
-                self.init_counts["MultiheadA"] += 1
+            elif isinstance(module, attention):
+                self.init_counts["attention"] += 1
             elif isinstance(module, Residual):
                 self.init_counts["Residual"] += 1
             elif isinstance(module, processor):
                 self.init_counts["processor"] += 1
-            elif isinstance(module, Echo):
-                self.init_counts["Echo"] += 1
 
     def init_weights(self):
         print("Initializing model weights...")
@@ -282,14 +214,7 @@ def main():
     token = ""
     log_dir = os.path.join('D:/newmodel/output/logs', datetime.now().strftime('%m-%d_%H_%M_%S'))
    os.makedirs(log_dir, exist_ok=True)
-    tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
-    sanity_check = False
-    streaming = False
-    load_saved = False
-    save_dataset = False
-    cache_dir = None
-    extract_args = None
+    tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
 
     extract_args = {
         "waveform": False,
@@ -317,71 +242,42 @@ def main():
         act="swish",
     )
 
-    train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=sanity_check, sample_rate=16000, streaming=streaming,
-        load_saved=load_saved, save_dataset=save_dataset, cache_dir=cache_dir, extract_args=extract_args, max_ctx=param.ctx)
+    train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
+        load_saved=False, save_dataset=False, cache_dir=None, extract_args=None, max_ctx=param.ctx)
 
-    model = Echo(param).to('cuda')
+    model = Model(param).to('cuda')
     print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
     print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
 
     from functools import partial
-    metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1,
-        tokenizer=tokenizer, model=model)
-    if sanity_check:
-        training_args = Seq2SeqTrainingArguments(
-            output_dir=log_dir,
-            per_device_train_batch_size=1,
-            per_device_eval_batch_size=1,
-            max_steps=10,
-            eval_steps=5,
-            save_steps=0,
-            warmup_steps=0,
-            logging_steps=1,
-            logging_dir=log_dir,
-            eval_strategy="steps",
-            save_strategy="no",
-            logging_strategy="no",
-            report_to=["tensorboard"],
-            push_to_hub=False,
-            save_total_limit=1,
-            label_names=["labels"],
-            save_safetensors=False,
-            eval_on_start=False,
-            batch_eval_metrics=False,
-            disable_tqdm=False,
-            include_tokens_per_second=True,
-            include_num_input_tokens_seen=True,
-            learning_rate=1e-7,
-            weight_decay=0.01,
-        )
-    else:
-        training_args = Seq2SeqTrainingArguments(
-            output_dir=log_dir,
-            per_device_train_batch_size=1,
-            per_device_eval_batch_size=1,
-            max_steps=1000,
-            eval_steps=100,
-            save_steps=1000,
-            warmup_steps=100,
-            logging_steps=10,
-            logging_dir=log_dir,
-            logging_strategy="steps",
-            eval_strategy="steps",
-            save_strategy="no",
-            report_to=["tensorboard"],
-            push_to_hub=False,
-            save_total_limit=1,
-            label_names=["labels"],
-            save_safetensors=False,
-            eval_on_start=False,
-            batch_eval_metrics=False,
-            disable_tqdm=False,
-            include_tokens_per_second=True,
-            include_num_input_tokens_seen=True,
-            learning_rate=0.00025,
-            weight_decay=0.025,
-        )
+    metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1, tokenizer=tokenizer, model=model)
+
+    training_args = Seq2SeqTrainingArguments(
+        output_dir=log_dir,
+        per_device_train_batch_size=1,
+        per_device_eval_batch_size=1,
+        max_steps=1000,
+        eval_steps=100,
+        save_steps=1000,
+        warmup_steps=100,
+        logging_steps=10,
+        logging_dir=log_dir,
+        logging_strategy="steps",
+        eval_strategy="steps",
+        save_strategy="no",
+        report_to=["tensorboard"],
+        push_to_hub=False,
+        save_total_limit=1,
+        label_names=["labels"],
+        save_safetensors=False,
+        eval_on_start=False,
+        batch_eval_metrics=False,
+        disable_tqdm=False,
+        include_tokens_per_second=True,
+        include_num_input_tokens_seen=True,
+        learning_rate=0.00025,
+        weight_decay=0.025,
+    )
 
     optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999),
         amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
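The commit collapses the old sanity_check/else split into the single full-training Seq2SeqTrainingArguments above and builds AdamW by hand. How the optimizer and a warmup schedule reach the trainer is not shown in this hunk; the following is only a plausible continuation sketch, assuming a transformers Seq2SeqTrainer and a linear warmup matching warmup_steps=100 / max_steps=1000 (the scheduler and trainer call are assumptions, not part of this commit):

# Hypothetical continuation of main(); not part of this commit.
from transformers import Seq2SeqTrainer, get_linear_schedule_with_warmup

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=training_args.warmup_steps,   # 100
    num_training_steps=training_args.max_steps,    # 1000
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=metrics_fn,
    optimizers=(optimizer, scheduler),             # use the hand-built AdamW instead of the default
)
trainer.train()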
 