Sin2pi committed
Commit 3c6d395 · verified · 1 Parent(s): 474e71e

Update model_simple.py

Files changed (1):
  1. model_simple.py +2 -8
model_simple.py CHANGED

@@ -269,8 +269,6 @@ class processor(nn.Module):
     def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
         super(processor, self).__init__()
 
-        self.ln = nn.LayerNorm(dims, device=device, dtype=dtype)
-        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
         self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
         self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
         self.posin = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
@@ -286,6 +284,7 @@ class processor(nn.Module):
 
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
+        self.ln = nn.LayerNorm(dims, device=device, dtype=dtype)
 
     def forward(self, x, xa, sequential=False) -> Tensor:
 
@@ -298,12 +297,7 @@ class processor(nn.Module):
 
         for b in chain(self.bB or []):
             x = b(x=x, xa=None, mask=self.mask)
-            y = b(x, xa=xa, mask=None)
-            if sequential:
-                x = y
-            else:
-                a = torch.sigmoid(self.blend)
-                x = a * y + (1 - a) * x
+            x = b(x, xa=xa, mask=None)
 
         x = nn.functional.dropout(x, p=0.001, training=self.training)
         x = self.ln(x)
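
Net effect of the diff: the decoder loop no longer blends the cross-attention pass into the self-attention pass through a learned sigmoid gate; each block is now simply applied twice in sequence (the `sequential` flag in `forward` is left unused), and `self.ln` moves below the mask registration. A minimal runnable sketch of the before/after loop body, using a hypothetical `DummyBlock` in place of the real entries of `self.bB`:

```python
import torch
import torch.nn as nn

class DummyBlock(nn.Module):
    # Hypothetical stand-in for one entry of self.bB: when xa is given it
    # works over xa, otherwise over x alone. The real blocks do attention;
    # this just projects and adds so the sketch runs end to end.
    def __init__(self, dims: int):
        super().__init__()
        self.proj = nn.Linear(dims, dims)

    def forward(self, x, xa=None, mask=None):
        src = xa if xa is not None else x
        return x + self.proj(src)

dims = 64
block = DummyBlock(dims)
x = torch.randn(2, 16, dims)   # token stream
xa = torch.randn(2, 16, dims)  # audio features

# Before this commit: the cross-attention output y was mixed back into
# the self-attention output through a learned gate (self.blend, init 0.5).
blend = nn.Parameter(torch.tensor(0.5))
h = block(x=x, xa=None, mask=None)   # self-attention pass
y = block(h, xa=xa, mask=None)       # cross-attention pass
a = torch.sigmoid(blend)             # gate value in (0, 1)
x_old = a * y + (1 - a) * h          # learned interpolation

# After this commit: the gate is gone and the two passes are chained,
# so the cross-attention output replaces x unconditionally.
x_new = block(block(x, xa=None, mask=None), xa=xa, mask=None)
```

With the gate removed, `sequential=True` and `sequential=False` now take the same path, so the flag could presumably be dropped from `forward`'s signature as well.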