Update model_simple.py
Browse files- model_simple.py +4 -3
model_simple.py
CHANGED
@@ -264,8 +264,9 @@ class processor(nn.Module):
|
|
264 |
xa = xa + self.audio_emb(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
|
265 |
|
266 |
for b in chain(self.bA or []):
|
267 |
-
xa = b(
|
268 |
-
x
|
|
|
269 |
xc = b(torch.cat([x, xa], dim=1), xa=None, mask=self.mask) if modal else None
|
270 |
x = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None) if modal else x
|
271 |
|
@@ -340,4 +341,4 @@ class Model(nn.Module):
|
|
340 |
print("Initialization summary:")
|
341 |
for module_type, count in self.init_counts.items():
|
342 |
if count > 0:
|
343 |
-
print(f"{module_type}: {count}")
|
|
|
264 |
xa = xa + self.audio_emb(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
|
265 |
|
266 |
for b in chain(self.bA or []):
|
267 |
+
xa = b(x=xa, xa=None, mask=None)
|
268 |
+
x = b(x=x, xa=None, mask=self.mask)
|
269 |
+
x = b(x=x, xa=xa, mask=None)
|
270 |
xc = b(torch.cat([x, xa], dim=1), xa=None, mask=self.mask) if modal else None
|
271 |
x = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None) if modal else x
|
272 |
|
|
|
341 |
print("Initialization summary:")
|
342 |
for module_type, count in self.init_counts.items():
|
343 |
if count > 0:
|
344 |
+
print(f"{module_type}: {count}")
|