Sin2pi commited on
Commit
fd4647c
·
verified ·
1 Parent(s): 446e362

Update model_simple.py

Browse files
Files changed (1) hide show
  1. model_simple.py +4 -3
model_simple.py CHANGED
@@ -264,8 +264,9 @@ class processor(nn.Module):
264
  xa = xa + self.audio_emb(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
265
 
266
  for b in chain(self.bA or []):
267
- xa = b(self.lna(xa))
268
- x = b(self.lnb(x), xa=xa, mask=self.mask)
 
269
  xc = b(torch.cat([x, xa], dim=1), xa=None, mask=self.mask) if modal else None
270
  x = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None) if modal else x
271
 
@@ -340,4 +341,4 @@ class Model(nn.Module):
340
  print("Initialization summary:")
341
  for module_type, count in self.init_counts.items():
342
  if count > 0:
343
- print(f"{module_type}: {count}")
 
264
  xa = xa + self.audio_emb(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
265
 
266
  for b in chain(self.bA or []):
267
+ xa = b(x=xa, xa=None, mask=None)
268
+ x = b(x=x, xa=None, mask=self.mask)
269
+ x = b(x=x, xa=xa, mask=None)
270
  xc = b(torch.cat([x, xa], dim=1), xa=None, mask=self.mask) if modal else None
271
  x = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None) if modal else x
272
 
 
341
  print("Initialization summary:")
342
  for module_type, count in self.init_counts.items():
343
  if count > 0:
344
+ print(f"{module_type}: {count}")