nppmatt committed
Commit d97d18d · Parent(s): 26c2ddd

Update app.py

Files changed (1):
  1. app.py +37 -36
app.py CHANGED
@@ -10,31 +10,6 @@ import streamlit as st
 # Define Torch device. Enable CUDA if available.
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Have data for BertClass ready for both models
-class BertClass(torch.nn.Module):
-    def __init__(self):
-        super(BertClass, self).__init__()
-        self.l1 = BertModel.from_pretrained(model_path)
-        self.dropout = torch.nn.Dropout(HEAD_DROP_OUT)
-        self.classifier = torch.nn.Linear(768, 6)
-
-    def forward(self, input_ids, attention_mask, token_type_ids):
-        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
-        hidden_state = output_1[0]
-        pooler = hidden_state[:, 0]
-        pooler = self.dropout(pooler)
-        output = self.classifier(pooler)
-        return output
-
-class PretrainedBertClass(torch.nn.Module):
-    def __init__(self):
-        super(BertClass, self).__init__()
-        self.l1 = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
-
-    def forward(self, input_ids, attention_mask, token_type_ids):
-        output = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
-        return output
-
 # Read and format data.
 tweets_raw = pd.read_csv("test.csv", nrows=20)
 labels_raw = pd.read_csv("test_labels.csv", nrows=20)
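
Aside: the deleted BertClass above (reintroduced lower in the file by this commit) pools BERT output by taking the final hidden state at the first token position ([CLS]) before dropout and a 6-way linear head. A minimal, self-contained sketch of that pooling pattern, assuming bert-base-uncased (hidden size 768) and the six toxicity labels used in this app:

    # Sketch of the [CLS]-pooling classifier-head pattern used by BertClass.
    import torch
    from transformers import AutoTokenizer, BertModel

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    encoder = BertModel.from_pretrained("bert-base-uncased")
    head = torch.nn.Sequential(torch.nn.Dropout(0.4), torch.nn.Linear(768, 6))

    batch = tokenizer(["an example comment"], return_tensors="pt", padding=True)
    hidden = encoder(**batch)[0]   # last hidden state, shape (batch, seq_len, 768)
    pooled = hidden[:, 0]          # first-token ([CLS]) embedding per sequence
    logits = head(pooled)          # shape (batch, 6): one logit per toxicity label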
@@ -45,17 +20,6 @@ label_vector = labels_raw[label_set].values.tolist()
 tweet_df = tweets_raw[["comment_text"]]
 tweet_df["labels"] = label_vector
 
-# User selects model for front-end.
-option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))
-
-bert_path = "bert-base-uncased"
-if option == "BERT":
-    tokenizer = AutoTokenizer.from_pretrained(bert_path)
-    model = PretrainedBertClass()
-else:
-    tokenizer = AutoTokenizer.from_pretrained(bert_path)
-    model = torch.load("pytorch_bert_toxic.bin", map_location=torch.device(device))
-
 # Dataset for loading tables into DataLoader
 class ToxicityDataset(Dataset):
     def __init__(self, dataframe, tokenizer, max_len):
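
Aside: the hunk above ends at the ToxicityDataset constructor; the class body sits outside the diff context. For orientation, a hypothetical sketch of how such a Dataset is commonly written with encode_plus — everything past the constructor signature is an assumption, not this repo's code:

    # Hypothetical sketch only; the real ToxicityDataset body is not shown here.
    import torch
    from torch.utils.data import Dataset

    class ToxicityDatasetSketch(Dataset):
        def __init__(self, dataframe, tokenizer, max_len):
            self.text = dataframe["comment_text"].tolist()
            self.labels = dataframe["labels"].tolist()
            self.tokenizer = tokenizer
            self.max_len = max_len

        def __len__(self):
            return len(self.text)

        def __getitem__(self, index):
            enc = self.tokenizer.encode_plus(
                self.text[index],
                max_length=self.max_len,
                padding="max_length",
                truncation=True,
                return_token_type_ids=True,
                return_tensors="pt",
            )
            return {
                "input_ids": enc["input_ids"].squeeze(0),
                "attention_mask": enc["attention_mask"].squeeze(0),
                "token_type_ids": enc["token_type_ids"].squeeze(0),
                "labels": torch.tensor(self.labels[index], dtype=torch.float),
            }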
@@ -94,10 +58,47 @@ class ToxicityDataset(Dataset):
 # Based on user model selection, prepare Dataset and DataLoader
 MAX_LENGTH = 100
 INFER_BATCH_SIZE = 128
+HEAD_DROP_OUT = 0.4
 infer_dataset = ToxicityDataset(tweet_df, tokenizer, MAX_LENGTH)
 infer_params = {"batch_size": INFER_BATCH_SIZE, "shuffle": False}
 infer_loader = DataLoader(infer_dataset, **infer_params)
 
+# Have data for BertClass ready for both models
+class BertClass(torch.nn.Module):
+    def __init__(self):
+        super(BertClass, self).__init__()
+        self.l1 = torch.load("pytorch_bert_toxic.bin", map_location=torch.device(device))
+        self.dropout = torch.nn.Dropout(HEAD_DROP_OUT)
+        self.classifier = torch.nn.Linear(768, 6)
+
+    def forward(self, input_ids, attention_mask, token_type_ids):
+        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+        hidden_state = output_1[0]
+        pooler = hidden_state[:, 0]
+        pooler = self.dropout(pooler)
+        output = self.classifier(pooler)
+        return output
+
+class PretrainedBertClass(torch.nn.Module):
+    def __init__(self):
+        super(PretrainedBertClass, self).__init__()
+        self.l1 = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
+
+    def forward(self, input_ids, attention_mask, token_type_ids):
+        output = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+        return output
+
+# User selects model for front-end.
+option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))
+
+bert_path = "bert-base-uncased"
+if option == "BERT":
+    tokenizer = AutoTokenizer.from_pretrained(bert_path)
+    model = PretrainedBertClass()
+else:
+    tokenizer = AutoTokenizer.from_pretrained(bert_path)
+    model = BertClass()
+
 # Freeze model and input tokens
 def inference():
     model.eval()
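
Aside: the diff truncates inference() right after model.eval(). A hedged sketch of how such a loop typically continues, assuming the script globals above (model, infer_loader, device), that model and batch tensors share a device, and that the selected model returns raw logits rather than a transformers output object:

    # Sketch only: the batch keys and the raw-logits assumption depend on code
    # not shown in this diff.
    def run_inference():
        model.eval()
        all_probs = []
        with torch.no_grad():                      # no gradients needed at inference
            for batch in infer_loader:
                logits = model(
                    batch["input_ids"].to(device),
                    batch["attention_mask"].to(device),
                    batch["token_type_ids"].to(device),
                )
                # Multi-label toxicity: a comment can be, e.g., toxic and obscene
                # at once, so squash each of the 6 logits independently with
                # sigmoid rather than softmax across labels.
                all_probs.append(torch.sigmoid(logits))
        return torch.cat(all_probs)

One Streamlit note: the script reruns top to bottom on every widget interaction, so wrapping model construction in a function decorated with st.cache_resource (available in recent Streamlit releases) would avoid reloading weights on each rerun.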
 
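
Aside: the new BertClass wraps a fully pickled module, so torch.load must be able to import the original class definition when unpickling pytorch_bert_toxic.bin. A sketch of the weights-only alternative, with hypothetical helper names:

    # Hypothetical helpers; only the torch.save / torch.load / load_state_dict
    # calls are standard PyTorch API.
    import torch

    def save_weights(model, path="pytorch_bert_toxic_state.bin"):
        torch.save(model.state_dict(), path)   # tensors only, no pickled class

    def load_weights(model, path="pytorch_bert_toxic_state.bin", device="cpu"):
        state = torch.load(path, map_location=torch.device(device))
        model.load_state_dict(state)           # restore into a fresh instance
        return model

A state_dict checkpoint survives refactors of the surrounding script better, since only parameter names need to match at load time.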