nppmatt commited on
Commit
22c2fe3
·
1 Parent(s): 47c22ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -24,10 +24,11 @@ class BertClass(torch.nn.Module):
24
  return output
25
 
26
  # Define models to be used
 
27
  bert_path = "bert-base-uncased"
28
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_path)
29
  bert_model = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
30
- tuned_model = model = torch.load("pytorch_bert_toxic.bin", map_location=torch.device("cpu"))
31
 
32
  # Read and format data.
33
  tweets_raw = pd.read_csv("test.csv", nrows=20)
 
24
  return output
25
 
26
  # Define models to be used
27
+ device = "cuda" if torch.cuda.is_available() else "cpu"
28
  bert_path = "bert-base-uncased"
29
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_path)
30
  bert_model = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
31
+ tuned_model = model = torch.load("pytorch_bert_toxic.bin", map_location=torch.device(device))
32
 
33
  # Read and format data.
34
  tweets_raw = pd.read_csv("test.csv", nrows=20)