nppmatt committed
Commit 26c2ddd · 1 Parent(s): 93270ec

Update app.py

Files changed (1):
  1. app.py +26 -17
app.py CHANGED
@@ -10,6 +10,31 @@ import streamlit as st
 # Define Torch device. Enable CUDA if available.
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+# Have data for BertClass ready for both models
+class BertClass(torch.nn.Module):
+    def __init__(self):
+        super(BertClass, self).__init__()
+        self.l1 = BertModel.from_pretrained(model_path)
+        self.dropout = torch.nn.Dropout(HEAD_DROP_OUT)
+        self.classifier = torch.nn.Linear(768, 6)
+
+    def forward(self, input_ids, attention_mask, token_type_ids):
+        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+        hidden_state = output_1[0]
+        pooler = hidden_state[:, 0]
+        pooler = self.dropout(pooler)
+        output = self.classifier(pooler)
+        return output
+
+class PretrainedBertClass(torch.nn.Module):
+    def __init__(self):
+        super(PretrainedBertClass, self).__init__()
+        self.l1 = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
+
+    def forward(self, input_ids, attention_mask, token_type_ids):
+        output = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+        return output
+
 # Read and format data.
 tweets_raw = pd.read_csv("test.csv", nrows=20)
 labels_raw = pd.read_csv("test_labels.csv", nrows=20)
@@ -26,27 +51,11 @@ option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT
 bert_path = "bert-base-uncased"
 if option == "BERT":
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
-    model = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
+    model = PretrainedBertClass()
 else:
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
     model = torch.load("pytorch_bert_toxic.bin", map_location=torch.device(device))
 
-# Have data for BertClass ready for our tuned model.
-class BertClass(torch.nn.Module):
-    def __init__(self):
-        super(BertClass, self).__init__()
-        self.l1 = BertModel.from_pretrained(model_path)
-        self.dropout = torch.nn.Dropout(HEAD_DROP_OUT)
-        self.classifier = torch.nn.Linear(768, 6)
-
-    def forward(self, input_ids, attention_mask, token_type_ids):
-        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
-        hidden_state = output_1[0]
-        pooler = hidden_state[:, 0]
-        pooler = self.dropout(pooler)
-        output = self.classifier(pooler)
-        return output
-
 # Dataset for loading tables into DataLoader
 class ToxicityDataset(Dataset):
     def __init__(self, dataframe, tokenizer, max_len):
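
For context, a minimal sketch of how the selected model might be called for a single prediction after this change, assuming multi-label sigmoid scoring over the six output units. The predict helper, the max_len default, and the label names are illustrative assumptions rather than code from app.py; note that PretrainedBertClass returns a transformers SequenceClassifierOutput (with a .logits attribute) while the fine-tuned BertClass returns a raw logits tensor, so the sketch handles both cases.

# Illustrative sketch only; not part of commit 26c2ddd. Assumes `tokenizer` and
# `model` were selected by the st.selectbox branch above.
import torch

# Assumed label order (Jigsaw-style toxicity heads); app.py does not name them here.
LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

def predict(text, tokenizer, model, device="cpu", max_len=128):
    model.to(device)
    model.eval()
    enc = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    with torch.no_grad():
        out = model(
            input_ids=enc["input_ids"].to(device),
            attention_mask=enc["attention_mask"].to(device),
            token_type_ids=enc["token_type_ids"].to(device),
        )
    # PretrainedBertClass returns a SequenceClassifierOutput; BertClass returns a tensor.
    logits = out.logits if hasattr(out, "logits") else out
    probs = torch.sigmoid(logits).squeeze(0).tolist()
    return dict(zip(LABELS, probs))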