nppmatt commited on
Commit
dc98418
1 Parent(s): d97d18d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -7,8 +7,14 @@ from transformers import AutoTokenizer, BertModel, BertForSequenceClassification
7
  from sklearn import metrics
8
  import streamlit as st
9
 
10
- # Define Torch device. Enable CUDA if available.
 
 
 
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
12
 
13
  # Read and format data.
14
  tweets_raw = pd.read_csv("test.csv", nrows=20)
@@ -56,9 +62,6 @@ class ToxicityDataset(Dataset):
56
  }
57
 
58
  # Based on user model selection, prepare Dataset and DataLoader
59
- MAX_LENGTH = 100
60
- INFER_BATCH_SIZE = 128
61
- HEAD_DROP_OUT = 0.4
62
  infer_dataset = ToxicityDataset(tweet_df, tokenizer, MAX_LENGTH)
63
  infer_params = {"batch_size": INFER_BATCH_SIZE, "shuffle": False}
64
  infer_loader = DataLoader(infer_dataset, **infer_params)
@@ -90,13 +93,9 @@ class PretrainedBertClass(torch.nn.Module):
90
 
91
  # User selects model for front-end.
92
  option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))
93
-
94
- bert_path = "bert-base-uncased"
95
  if option == "BERT":
96
- tokenizer = AutoTokenizer.from_pretrained(bert_path)
97
  model = PretrainedBertClass()
98
  else:
99
- tokenizer = AutoTokenizer.from_pretrained(bert_path)
100
  model = BertClass()
101
 
102
  # Freeze model and input tokens
 
7
  from sklearn import metrics
8
  import streamlit as st
9
 
10
+ # Define constants. Enable CUDA if available.
11
+ MAX_LENGTH = 100
12
+ INFER_BATCH_SIZE = 128
13
+ HEAD_DROP_OUT = 0.4
14
+
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ bert_path = "bert-base-uncased"
17
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
18
 
19
  # Read and format data.
20
  tweets_raw = pd.read_csv("test.csv", nrows=20)
 
62
  }
63
 
64
  # Based on user model selection, prepare Dataset and DataLoader
 
 
 
65
  infer_dataset = ToxicityDataset(tweet_df, tokenizer, MAX_LENGTH)
66
  infer_params = {"batch_size": INFER_BATCH_SIZE, "shuffle": False}
67
  infer_loader = DataLoader(infer_dataset, **infer_params)
 
93
 
94
  # User selects model for front-end.
95
  option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))
 
 
96
  if option == "BERT":
 
97
  model = PretrainedBertClass()
98
  else:
 
99
  model = BertClass()
100
 
101
  # Freeze model and input tokens