Update model_simple.py
model_simple.py CHANGED (+2, −8)
@@ -269,8 +269,6 @@ class processor(nn.Module):
     def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
         super(processor, self).__init__()
 
-        self.ln = nn.LayerNorm(dims, device=device, dtype=dtype)
-        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
         self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
         self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
         self.posin = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
@@ -286,6 +284,7 @@ class processor(nn.Module):
 
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
+        self.ln = nn.LayerNorm(dims, device=device, dtype=dtype)
 
     def forward(self, x, xa, sequential=False) -> Tensor:
 
@@ -298,12 +297,7 @@ class processor(nn.Module):
 
         for b in chain(self.bB or []):
            x = b(x=x, xa=None, mask=self.mask)
-
-        if sequential:
-            x = y
-        else:
-            a = torch.sigmoid(self.blend)
-            x = a * y + (1 - a) * x
+           x = b(x, xa=xa, mask=None)
 
         x = nn.functional.dropout(x, p=0.001, training=self.training)
         x = self.ln(x)
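For reference, the deleted lines implemented a learned scalar gate between two paths. A minimal stand-alone sketch of that pattern follows; BlendGate is a hypothetical name, and the assumption that y held a second path's output is an inference, since y is assigned outside the hunks shown:

import torch
import torch.nn as nn

class BlendGate(nn.Module):
    """Sketch of the removed pattern: one learned scalar gates between
    two paths. Hypothetical stand-alone module; in the diff, x is the
    block-loop output and y comes from code not shown in these hunks."""
    def __init__(self):
        super().__init__()
        # Learned scalar, initialized to 0.5 as in the removed line.
        self.blend = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, y, sequential=False):
        if sequential:
            return y                   # bypass: take the second path only
        a = torch.sigmoid(self.blend)  # squash the scalar into (0, 1)
        return a * y + (1 - a) * x     # learned convex combination

Removing this branch also removes the only visible use of self.blend, which is why the parameter disappears from __init__; as far as these hunks show, the sequential argument to forward no longer has any effect after this commit.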
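Read together, the hunks leave each block running twice per loop iteration: once self-attending under the causal -inf mask, then once cross-attending over xa with no mask, followed by dropout and the relocated final LayerNorm. Below is a minimal runnable sketch of that flow under stated assumptions: the Block class, the hyperparameters, and the omission of the positional terms are all placeholders, since the real block type behind self.bB is not part of this diff.

import numpy as np
import torch
import torch.nn as nn
from itertools import chain

class Block(nn.Module):
    """Hypothetical stand-in for the entries of self.bB; only the call
    signature b(x, xa=..., mask=...) is taken from the diff."""
    def __init__(self, dims, head):
        super().__init__()
        self.attn = nn.MultiheadAttention(dims, head, batch_first=True)
        self.norm = nn.LayerNorm(dims)

    def forward(self, x, xa=None, mask=None):
        q = self.norm(x)
        kv = q if xa is None else xa  # xa=None -> self-attention, else cross-attention
        out, _ = self.attn(q, kv, kv, attn_mask=mask)
        return x + out                # residual connection

class processor(nn.Module):
    def __init__(self, vocab=100, ctx=16, dims=64, head=4, layer=2):
        super().__init__()
        self.token = nn.Embedding(vocab, dims)
        self.bB = nn.ModuleList(Block(dims, head) for _ in range(layer))
        mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)  # causal mask
        self.register_buffer("mask", mask, persistent=False)
        self.ln = nn.LayerNorm(dims)  # built after the mask, as in the new code

    def forward(self, x, xa):
        x = self.token(x)  # positional terms omitted for brevity
        for b in chain(self.bB or []):
            x = b(x, xa=None, mask=self.mask)  # masked self-attention
            x = b(x, xa=xa, mask=None)         # unmasked cross-attention over xa
        x = nn.functional.dropout(x, p=0.001, training=self.training)
        return self.ln(x)

p = processor()
tokens = torch.randint(0, 100, (1, 16))  # (batch, ctx)
audio = torch.randn(1, 8, 64)            # (batch, frames, dims)
print(p(tokens, audio).shape)            # torch.Size([1, 16, 64])

Note that the diff calls the same block b for both the masked and the unmasked pass, so self- and cross-attention share weights unless the real block dispatches internally; that detail is not visible in these hunks.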
|