# milestone-2 / app.py
# Author: nppmatt — last commit 67d3612 ("add padding to xlm"), 2.09 kB
import streamlit as st
import plotly.express as px
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Placeholder text pre-filled in the text area.
defaultTxt = "I hate you cancerous insects so much"

# Widgets render in call order: model picker first, then the input text box.
option = st.selectbox(
    "Select a toxicity analysis model:",
    ("RoBERTa", "DistilBERT", "XLM-RoBERTa"),
)
txt = st.text_area("Text to analyze", defaultTxt)

# The button's return value is deliberately ignored; clicking it only
# causes Streamlit to re-run the script with the current inputs.
st.button("Submit Text")
# Per-model configuration: (Hugging Face checkpoint, neutral logit index,
# toxic logit index). The tokenizer and the model weights are loaded from the
# same checkpoint for every option. The label order differs between
# checkpoints, so the neutral/toxic column indices are recorded per model.
#   RoBERTa:              [0]: neutral, [1]: toxic
#   DistilBERT / XLM-R.:  [0]: toxic,   [1]: neutral
MODEL_CONFIGS = {
    "RoBERTa": ("s-nlp/roberta_toxicity_classifier", 0, 1),
    "DistilBERT": ("citizenlab/distilbert-base-multilingual-cased-toxicity", 1, 0),
    "XLM-RoBERTa": ("unitary/multilingual-toxic-xlm-roberta", 1, 0),
}

# Any unrecognised selection falls back to RoBERTa.
checkpoint, neutralIndex, toxicIndex = MODEL_CONFIGS.get(option, MODEL_CONFIGS["RoBERTa"])
tokenizerPath = checkpoint
modelPath = checkpoint
# Load the tokenizer and classifier for the selected checkpoint.
tokenizer = AutoTokenizer.from_pretrained(tokenizerPath)
model = AutoModelForSequenceClassification.from_pretrained(modelPath)

# Encode the input text. truncation=True clips text that is longer than the
# model's maximum sequence length instead of letting the forward pass fail.
encoding = tokenizer.encode(txt, truncation=True, return_tensors='pt')

# Inference only: no_grad avoids building an autograd graph.
with torch.no_grad():
    result = model(encoding)
logits = result.logits

# A classifier head with a single output column is padded with a constant 0
# logit. softmax([x, 0]) == [sigmoid(x), 1 - sigmoid(x)], so column 0 becomes
# the toxic probability and column 1 the neutral one.
# NOTE(review): the "add padding to xlm" commit suggests the XLM-RoBERTa
# checkpoint is the single-logit case; confirm against its config.
# pad() returns a new tensor, so its result must be reassigned. Padding is
# skipped for the two-logit models, where it would add a bogus third class.
if logits.shape[-1] == 1:
    logits = nn.functional.pad(logits, (0, 1), "constant", 0)

# Turn logits into class probabilities for the single encoded sequence.
prediction = nn.functional.softmax(logits, dim=-1)
neutralProb = prediction[0][neutralIndex].item()
toxicProb = prediction[0][toxicIndex].item()

# Expected returns from RoBERTa on default text:
# Neutral: 0.0052
# Toxic: 0.9948
st.write("Classification Probabilities")
st.write(f"{neutralProb:.4f} - NEUTRAL")
st.write(f"{toxicProb:.4f} - TOXIC")