Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -91,11 +91,11 @@ class PretrainedBertClass(torch.nn.Module):
|
|
91 |
def __init__(self):
|
92 |
super(PretrainedBertClass, self).__init__()
|
93 |
self.l1 = BertModel.from_pretrained(bert_path)
|
94 |
-
self.l2 = torch.nn.Dropout(
|
95 |
self.l3 = torch.nn.Linear(768, 6)
|
96 |
|
97 |
def forward(self, ids, mask, token_type_ids):
|
98 |
-
_, output_1= self.l1(ids, attention_mask = mask, token_type_ids = token_type_ids)
|
99 |
output_2 = self.l2(output_1)
|
100 |
output = self.l3(output_2)
|
101 |
return output
|
|
|
91 |
def __init__(self):
|
92 |
super(PretrainedBertClass, self).__init__()
|
93 |
self.l1 = BertModel.from_pretrained(bert_path)
|
94 |
+
self.l2 = torch.nn.Dropout(HEAD_DROP_OUT)
|
95 |
self.l3 = torch.nn.Linear(768, 6)
|
96 |
|
97 |
def forward(self, ids, mask, token_type_ids):
|
98 |
+
_, output_1= self.l1(ids, attention_mask = mask, token_type_ids = token_type_ids, return_dict=False)
|
99 |
output_2 = self.l2(output_1)
|
100 |
output = self.l3(output_2)
|
101 |
return output
|