nppmatt commited on
Commit
93270ec
·
1 Parent(s): 22c2fe3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -25
app.py CHANGED
@@ -7,28 +7,8 @@ from transformers import AutoTokenizer, BertModel, BertForSequenceClassification
7
  from sklearn import metrics
8
  import streamlit as st
9
 
10
- # Have data for BertClass ready for our tuned model.
11
- class BertClass(torch.nn.Module):
12
- def __init__(self):
13
- super(BertClass, self).__init__()
14
- self.l1 = BertModel.from_pretrained(model_path)
15
- self.dropout = torch.nn.Dropout(HEAD_DROP_OUT)
16
- self.classifier = torch.nn.Linear(768, 6)
17
-
18
- def forward(self, input_ids, attention_mask, token_type_ids):
19
- output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
20
- hidden_state = output_1[0]
21
- pooler = hidden_state[:, 0]
22
- pooler = self.dropout(pooler)
23
- output = self.classifier(pooler)
24
- return output
25
-
26
- # Define models to be used
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
- bert_path = "bert-base-uncased"
29
- bert_tokenizer = AutoTokenizer.from_pretrained(bert_path)
30
- bert_model = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
31
- tuned_model = model = torch.load("pytorch_bert_toxic.bin", map_location=torch.device(device))
32
 
33
  # Read and format data.
34
  tweets_raw = pd.read_csv("test.csv", nrows=20)
@@ -42,12 +22,30 @@ tweet_df["labels"] = label_vector
42
 
43
  # User selects model for front-end.
44
  option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))
 
 
45
  if option == "BERT":
46
- tokenizer = bert_tokenizer
47
- model = bert_model
48
  else:
49
- tokenizer = bert_tokenizer
50
- model = tuned_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # Dataset for loading tables into DataLoader
53
  class ToxicityDataset(Dataset):
 
7
  from sklearn import metrics
8
  import streamlit as st
9
 
10
+ # Define Torch device. Enable CUDA if available.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
12
 
13
  # Read and format data.
14
  tweets_raw = pd.read_csv("test.csv", nrows=20)
 
22
 
23
  # User selects model for front-end.
24
  option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))
25
+
26
+ bert_path = "bert-base-uncased"
27
  if option == "BERT":
28
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
29
+ model = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
30
  else:
31
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
32
+ model = torch.load("pytorch_bert_toxic.bin", map_location=torch.device(device))
33
+
34
+ # Have data for BertClass ready for our tuned model.
35
+ class BertClass(torch.nn.Module):
36
+ def __init__(self):
37
+ super(BertClass, self).__init__()
38
+ self.l1 = BertModel.from_pretrained(model_path)
39
+ self.dropout = torch.nn.Dropout(HEAD_DROP_OUT)
40
+ self.classifier = torch.nn.Linear(768, 6)
41
+
42
+ def forward(self, input_ids, attention_mask, token_type_ids):
43
+ output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
44
+ hidden_state = output_1[0]
45
+ pooler = hidden_state[:, 0]
46
+ pooler = self.dropout(pooler)
47
+ output = self.classifier(pooler)
48
+ return output
49
 
50
  # Dataset for loading tables into DataLoader
51
  class ToxicityDataset(Dataset):