# Source captured from a Hugging Face Spaces page (status shown: "Runtime error").
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as TF
from transformers import AutoTokenizer, BertForSequenceClassification
# Streamlit front end for a BERT sequence classifier: pick a model, enter
# text, and display NEUTRAL / TOXIC probabilities.

# Columns of the softmax output reported below. The original script referenced
# neutralIndex / toxicIndex without ever defining them.
# NOTE(review): 0 / 1 is an assumption — confirm against the chosen model's
# label order (model.config.id2label).
neutralIndex = 0
toxicIndex = 1

# Six label names; matches num_labels=6 below.
# NOTE(review): presumably the label columns of train.csv — verify. Not used
# elsewhere in this file.
label_set = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))
bert_path = "bert-base-uncased"

if option == "BERT":
    tokenizer = AutoTokenizer.from_pretrained(bert_path)
    # NOTE(review): the 6-way classification head on top of the base
    # checkpoint is newly initialised, so its scores are not meaningful
    # until the model is fine-tuned.
    model = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
else:
    tweets_raw = pd.read_csv("train.csv", nrows=20)
    # This branch never loaded a tokenizer/model, so the inference code below
    # crashed with a NameError. Stop with an explicit message instead.
    # TODO: load the fine-tuned checkpoint here and remove the stop.
    st.error("The fine-tuned BERT model is not available yet.")
    st.stop()

# Nothing to classify until the user enters some text.
txt = st.text_area("Enter text to analyze:")
if not txt:
    st.stop()

# Run encoding through model to get classification output.
# truncation=True caps the input at the tokenizer's maximum length instead of
# failing on long text.
encoding = tokenizer.encode(txt, truncation=True, return_tensors="pt")
with torch.no_grad():  # inference only; no gradients needed
    logits = model(encoding).logits

# Transform logits to probabilities. A model with a single output column is
# padded with a constant 0 logit so softmax still yields two columns.
if logits.size(dim=1) < 2:
    logits = TF.pad(logits, (0, 1), "constant", 0)
prediction = TF.softmax(logits, dim=-1)
neutralProb = prediction[0][neutralIndex].item()
toxicProb = prediction[0][toxicIndex].item()

# Write results
st.write("Classification Probabilities")
st.write(f"{neutralProb:.4f} - NEUTRAL")
st.write(f"{toxicProb:.4f} - TOXIC")