nppmatt commited on
Commit
c2dd68b
1 Parent(s): 408fdfa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -91,11 +91,11 @@ class PretrainedBertClass(torch.nn.Module):
91
  def __init__(self):
92
  super(PretrainedBertClass, self).__init__()
93
  self.l1 = BertModel.from_pretrained(bert_path)
94
- self.l2 = torch.nn.Dropout(0.3)
95
  self.l3 = torch.nn.Linear(768, 6)
96
 
97
  def forward(self, ids, mask, token_type_ids):
98
- _, output_1= self.l1(ids, attention_mask = mask, token_type_ids = token_type_ids)
99
  output_2 = self.l2(output_1)
100
  output = self.l3(output_2)
101
  return output
 
91
  def __init__(self):
92
  super(PretrainedBertClass, self).__init__()
93
  self.l1 = BertModel.from_pretrained(bert_path)
94
+ self.l2 = torch.nn.Dropout(HEAD_DROP_OUT)
95
  self.l3 = torch.nn.Linear(768, 6)
96
 
97
  def forward(self, ids, mask, token_type_ids):
98
+ _, output_1= self.l1(ids, attention_mask = mask, token_type_ids = token_type_ids, return_dict=False)
99
  output_2 = self.l2(output_1)
100
  output = self.l3(output_2)
101
  return output