oucgc1996 commited on
Commit
6caa644
·
verified ·
1 Parent(s): 44bd3fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -82,7 +82,7 @@ class AttentionBlock(nn.Module):
82
 
83
  self.do = nn.Dropout(dropout)
84
 
85
- self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).cuda()
86
 
87
  def forward(self, query, key, value, mask=None):
88
  batch_size = query.shape[0]
 
82
 
83
  self.do = nn.Dropout(dropout)
84
 
85
+ self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).to(device)
86
 
87
  def forward(self, query, key, value, mask=None):
88
  batch_size = query.shape[0]