nppmatt commited on
Commit
28af2c4
·
1 Parent(s): dc98418

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -17
app.py CHANGED
@@ -66,22 +66,6 @@ infer_dataset = ToxicityDataset(tweet_df, tokenizer, MAX_LENGTH)
66
  infer_params = {"batch_size": INFER_BATCH_SIZE, "shuffle": False}
67
  infer_loader = DataLoader(infer_dataset, **infer_params)
68
 
69
- # Have data for BertClass ready for both models
70
- class BertClass(torch.nn.Module):
71
- def __init__(self):
72
- super(BertClass, self).__init__()
73
- self.l1 = torch.load("pytorch_bert_toxic.bin", map_location=torch.device(device))
74
- self.dropout = torch.nn.Dropout(HEAD_DROP_OUT)
75
- self.classifier = torch.nn.Linear(768, 6)
76
-
77
- def forward(self, input_ids, attention_mask, token_type_ids):
78
- output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
79
- hidden_state = output_1[0]
80
- pooler = hidden_state[:, 0]
81
- pooler = self.dropout(pooler)
82
- output = self.classifier(pooler)
83
- return output
84
-
85
  class PretrainedBertClass(torch.nn.Module):
86
  def __init__(self):
87
  super(PretrainedBertClass, self).__init__()
@@ -96,7 +80,7 @@ option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT
96
  if option == "BERT":
97
  model = PretrainedBertClass()
98
  else:
99
- model = BertClass()
100
 
101
  # Freeze model and input tokens
102
  def inference():
 
66
  infer_params = {"batch_size": INFER_BATCH_SIZE, "shuffle": False}
67
  infer_loader = DataLoader(infer_dataset, **infer_params)
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  class PretrainedBertClass(torch.nn.Module):
70
  def __init__(self):
71
  super(PretrainedBertClass, self).__init__()
 
80
  if option == "BERT":
81
  model = PretrainedBertClass()
82
  else:
83
+ model = torch.load("pytorch_bert_toxic.bin")
84
 
85
  # Freeze model and input tokens
86
  def inference():