Update app.py
Browse files
app.py
CHANGED
@@ -82,7 +82,7 @@ class AttentionBlock(nn.Module):
|
|
82 |
|
83 |
self.do = nn.Dropout(dropout)
|
84 |
|
85 |
-
self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).
|
86 |
|
87 |
def forward(self, query, key, value, mask=None):
|
88 |
batch_size = query.shape[0]
|
|
|
82 |
|
83 |
self.do = nn.Dropout(dropout)
|
84 |
|
85 |
+
self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).to(device)
|
86 |
|
87 |
def forward(self, query, key, value, mask=None):
|
88 |
batch_size = query.shape[0]
|