nppmatt commited on
Commit
73571fc
·
1 Parent(s): 28af2c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -0
app.py CHANGED
@@ -66,6 +66,21 @@ infer_dataset = ToxicityDataset(tweet_df, tokenizer, MAX_LENGTH)
66
  infer_params = {"batch_size": INFER_BATCH_SIZE, "shuffle": False}
67
  infer_loader = DataLoader(infer_dataset, **infer_params)
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  class PretrainedBertClass(torch.nn.Module):
70
  def __init__(self):
71
  super(PretrainedBertClass, self).__init__()
 
66
  infer_params = {"batch_size": INFER_BATCH_SIZE, "shuffle": False}
67
  infer_loader = DataLoader(infer_dataset, **infer_params)
68
 
69
+ class BertClass(torch.nn.Module):
70
+ def __init__(self):
71
+ super(BertClass, self).__init__()
72
+ self.l1 = BertModel.from_pretrained(bert_path)
73
+ self.dropout = torch.nn.Dropout(HEAD_DROP_OUT)
74
+ self.classifier = torch.nn.Linear(768, 6)
75
+
76
+ def forward(self, input_ids, attention_mask, token_type_ids):
77
+ output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
78
+ hidden_state = output_1[0]
79
+ pooler = hidden_state[:, 0]
80
+ pooler = self.dropout(pooler)
81
+ output = self.classifier(pooler)
82
+ return output
83
+
84
  class PretrainedBertClass(torch.nn.Module):
85
  def __init__(self):
86
  super(PretrainedBertClass, self).__init__()