Spaces: Runtime error
Update app.py
app.py
CHANGED
@@ -1,54 +1,34 @@
-
import torch
import torch.nn.functional as TF
import streamlit as st

-

-
-
-
-
-
- if (option == "RoBERTa"):
-     tokenizerPath = "s-nlp/roberta_toxicity_classifier"
-     modelPath = "s-nlp/roberta_toxicity_classifier"
-     neutralIndex = 0
-     toxicIndex = 1
- elif (option == "DistilBERT"):
-     tokenizerPath = "citizenlab/distilbert-base-multilingual-cased-toxicity"
-     modelPath = "citizenlab/distilbert-base-multilingual-cased-toxicity"
-     neutralIndex = 1
-     toxicIndex = 0
- elif (option == "XLM-RoBERTa"):
-     tokenizerPath = "unitary/multilingual-toxic-xlm-roberta"
-     modelPath = "unitary/multilingual-toxic-xlm-roberta"
-     neutralIndex = 1
-     toxicIndex = 0
else:
-
-     modelPath = "s-nlp/roberta_toxicity_classifier"
-     neutralIndex = 0
-     toxicIndex = 1
-
- tokenizer = AutoTokenizer.from_pretrained(tokenizerPath)
- model = AutoModelForSequenceClassification.from_pretrained(modelPath)

-
- token_ids = tokenizer.convert_tokens_to_ids(tokens)
- input_ids = tokenizer(input_text)

- batch = tokenizer(X_train, padding=True, truncation=True, max_length=512, return_tensors="pt")

-
-
-
- labels = torch.argmax(predictions, dim=1)
- labels = [model.config.id2label[label_id] for label_id in labels.tolist()]

-
-
-

-
-
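The removed version let the user choose between three pretrained toxicity classifiers from the Hugging Face Hub and hard-coded each model's neutral/toxic label indices. A minimal runnable sketch of that flow, assuming the blank removed lines held the `transformers` import plus a Streamlit selectbox and text input (those widget names and labels are hypothetical, and `tokenizerPath`/`modelPath` are collapsed into one path since they were always identical):

import torch
import torch.nn.functional as TF
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification  # assumed import, not shown in the diff

# Hypothetical widgets standing in for the blank removed lines.
option = st.selectbox("Select a toxicity model:", ("RoBERTa", "DistilBERT", "XLM-RoBERTa"))
input_text = st.text_area("Text to classify:")

# Each checkpoint orders its labels differently, so the neutral/toxic indices are set per model.
if option == "RoBERTa":
    modelPath = "s-nlp/roberta_toxicity_classifier"
    neutralIndex, toxicIndex = 0, 1
elif option == "DistilBERT":
    modelPath = "citizenlab/distilbert-base-multilingual-cased-toxicity"
    neutralIndex, toxicIndex = 1, 0
else:  # XLM-RoBERTa
    modelPath = "unitary/multilingual-toxic-xlm-roberta"
    neutralIndex, toxicIndex = 1, 0

tokenizer = AutoTokenizer.from_pretrained(modelPath)
model = AutoModelForSequenceClassification.from_pretrained(modelPath)

if input_text:
    # Tokenize, run the model, and turn the logits into class probabilities.
    batch = tokenizer(input_text, padding=True, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        logits = model(**batch).logits
    predictions = TF.softmax(logits, dim=-1)
    st.write(f"{predictions[0][neutralIndex].item():.4f} - NEUTRAL")
    st.write(f"{predictions[0][toxicIndex].item():.4f} - TOXIC")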
+ import pandas as pd
import torch
import torch.nn.functional as TF
import streamlit as st

+ option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))

+ bert_path = "bert-base-uncased"
+ if (option == "BERT"):
+
+     tokenizer = AutoTokenizer.from_pretrained(bert_path)
+     model = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
else:
+

+ tweets_raw = pd.read_csv("train.csv", nrows=20)


+ # Run encoding through model to get classification output.
+ encoding = tokenizer.encode(txt, return_tensors='pt')
+ result = model(encoding)

+ # Transform logit to get probabilities.
+ if (result.logits.size(dim=1) < 2):
+     pad = (0, 1)
+     result.logits = nn.functional.pad(result.logits, pad, "constant", 0)
+ prediction = nn.functional.softmax(result.logits, dim=-1)
+ neutralProb = prediction.data[0][neutralIndex]
+ toxicProb = prediction.data[0][toxicIndex]

+ # Write results
+ st.write("Classification Probabilities")
+ st.write(f"{neutralProb:.4f} - NEUTRAL")
+ st.write(f"{toxicProb:.4f} - TOXIC")
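As committed, the new version references several names that are never imported or defined (`AutoTokenizer`, `BertForSequenceClassification`, `nn`, `txt`, `neutralIndex`, `toxicIndex`), which would raise a NameError when the script runs and is consistent with the Space's Runtime error status. A minimal runnable sketch under stated assumptions: the missing `transformers` imports are added, `torch.nn.functional` (already imported as `TF`) stands in for the undefined `nn`, and the undefined `txt`, `neutralIndex`, and `toxicIndex` are replaced by a hypothetical text box and binary label indices:

import pandas as pd
import torch
import torch.nn.functional as TF
import streamlit as st
from transformers import AutoTokenizer, BertForSequenceClassification  # assumed imports, not in the diff

option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))

bert_path = "bert-base-uncased"
if option == "BERT":
    tokenizer = AutoTokenizer.from_pretrained(bert_path)
    model = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
else:
    # Assumption: the blank else branch in the diff would load a fine-tuned checkpoint;
    # the base model is used here as a placeholder.
    tokenizer = AutoTokenizer.from_pretrained(bert_path)
    model = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)

# train.csv is expected to sit next to app.py, as in the committed code.
tweets_raw = pd.read_csv("train.csv", nrows=20)

# Hypothetical stand-ins for names the committed code never defines.
txt = st.text_area("Text to classify:", "I hope you have a great day!")
neutralIndex, toxicIndex = 0, 1

# Run the encoding through the model to get classification logits.
encoding = tokenizer.encode(txt, return_tensors="pt")
with torch.no_grad():
    logits = model(encoding).logits

# Pad single-logit outputs to two columns, then convert logits to probabilities.
if logits.size(dim=1) < 2:
    logits = TF.pad(logits, (0, 1), "constant", 0)
prediction = TF.softmax(logits, dim=-1)
neutralProb = prediction[0][neutralIndex].item()
toxicProb = prediction[0][toxicIndex].item()

# Write results.
st.write("Classification Probabilities")
st.write(f"{neutralProb:.4f} - NEUTRAL")
st.write(f"{toxicProb:.4f} - TOXIC")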