nppmatt commited on
Commit
a01490b
1 Parent(s): 6407600

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -21
app.py CHANGED
@@ -67,28 +67,40 @@ infer_params = {"batch_size": INFER_BATCH_SIZE, "shuffle": False}
67
  infer_loader = DataLoader(infer_dataset, **infer_params)
68
 
69
  class BertClass(torch.nn.Module):
70
- def __init__(self):
71
- super(BertClass, self).__init__()
72
- self.l1 = BertModel.from_pretrained(bert_path)
73
- self.dropout = torch.nn.Dropout(HEAD_DROP_OUT)
74
- self.classifier = torch.nn.Linear(768, 6)
75
-
76
- def forward(self, input_ids, attention_mask, token_type_ids):
77
- output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
78
- hidden_state = output_1[0]
79
- pooler = hidden_state[:, 0]
80
- pooler = self.dropout(pooler)
81
- output = self.classifier(pooler)
82
- return output
 
 
 
 
 
 
83
 
84
  class PretrainedBertClass(torch.nn.Module):
85
- def __init__(self):
86
- super(PretrainedBertClass, self).__init__()
87
- self.l1 = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
88
-
89
- def forward(self, input_ids, attention_mask, token_type_ids):
90
- output = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
91
- return output
 
 
 
 
 
 
92
 
93
  # User selects model for front-end.
94
  option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))
@@ -108,11 +120,12 @@ def inference():
108
  mask = data["mask"].to(device, dtype=torch.long)
109
  token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)
110
  targets = data["targets"].to(device, dtype=torch.float)
111
- outputs = model(ids, mask, token_type_ids, return_dict=False)
112
  final_targets.extend(targets.cpu().detach().numpy().tolist())
113
  final_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
114
  return final_outputs, final_targets
115
 
 
116
  prediction, targets = inference()
117
  prediction = np.array(prediction) >= 0.5
118
  targets = np.argmax(targets, axis=1)
 
67
  infer_loader = DataLoader(infer_dataset, **infer_params)
68
 
69
  class BertClass(torch.nn.Module):
70
+ def __init__(self):
71
+ super(BertClass, self).__init__()
72
+ self.l1 = BertModel.from_pretrained(bert_path)
73
+ self.dropout = torch.nn.Dropout(HEAD_DROP_OUT)
74
+ self.classifier = torch.nn.Linear(768, 6)
75
+
76
+ # return_dict must equal False for Huggingface Transformers v4+
77
+ def forward(self, input_ids, attention_mask, token_type_ids):
78
+ output_1 = self.l1(
79
+ input_ids=input_ids,
80
+ attention_mask=attention_mask,
81
+ token_type_ids=token_type_ids,
82
+ return_dict=False,
83
+ )
84
+ hidden_state = output_1[0]
85
+ pooler = hidden_state[:, 0]
86
+ pooler = self.dropout(pooler)
87
+ output = self.classifier(pooler)
88
+ return output
89
 
90
  class PretrainedBertClass(torch.nn.Module):
91
+ def __init__(self):
92
+ super(PretrainedBertClass, self).__init__()
93
+ self.l1 = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
94
+
95
+ # return_dict must equal False for Huggingface Transformers v4+
96
+ def forward(self, input_ids, attention_mask, token_type_ids):
97
+ output = self.l1(
98
+ input_ids=input_ids,
99
+ attention_mask=attention_mask,
100
+ token_type_ids=token_type_ids,
101
+ return_dict=False,
102
+ )
103
+ return output
104
 
105
  # User selects model for front-end.
106
  option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))
 
120
  mask = data["mask"].to(device, dtype=torch.long)
121
  token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)
122
  targets = data["targets"].to(device, dtype=torch.float)
123
+ outputs = model(ids, mask, token_type_ids)
124
  final_targets.extend(targets.cpu().detach().numpy().tolist())
125
  final_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
126
  return final_outputs, final_targets
127
 
128
+
129
  prediction, targets = inference()
130
  prediction = np.array(prediction) >= 0.5
131
  targets = np.argmax(targets, axis=1)