# Toxiclassifier/app.py: multi-label toxicity classifier built on BERT
# Requires: pip install transformers torch
# (the notebook magic "!pip install ..." is not valid in a plain .py script)
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load the pretrained BERT backbone and tokenizer.
# NOTE: bert-base-uncased ships without a toxicity head; passing num_labels=6
# attaches a randomly initialized classification head, so the model must be
# fine-tuned before its predictions mean anything.
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=6)
model.eval()  # inference mode: disables dropout
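# Sketch of an alternative, assuming a checkpoint already fine-tuned on these
# six Jigsaw labels is available on the Hub; 'unitary/toxic-bert' is one such
# public example (verify its model card and label order before relying on it):
#
#   model_name = 'unitary/toxic-bert'
#   tokenizer = AutoTokenizer.from_pretrained(model_name)
#   model = AutoModelForSequenceClassification.from_pretrained(model_name)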
# Define the labels (the Jigsaw toxic-comment categories) and their indices
labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
label2id = {label: i for i, label in enumerate(labels)}
model.config.label2id = label2id  # keep the model config in sync with our label order
model.config.id2label = {i: label for label, i in label2id.items()}
# Preprocess a text input into model-ready tensors
def preprocess(text):
    inputs = tokenizer(text, padding=True, truncation=True, max_length=128, return_tensors='pt')
    return inputs['input_ids'], inputs['attention_mask']
# Classify a text input: return every label whose sigmoid score exceeds 0.5
def classify(text):
    input_ids, attention_mask = preprocess(text)
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits
    preds = torch.sigmoid(logits) > 0.5  # multi-label: independent threshold per label
    return [labels[i] for i, pred in enumerate(preds.squeeze().tolist()) if pred]
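# Illustrative companion helper (not in the original app): return the raw
# per-label probabilities instead of hard 0.5-threshold decisions, which is
# often more useful for inspecting an untuned or borderline model.
def classify_with_scores(text):
    input_ids, attention_mask = preprocess(text)
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits
    probs = torch.sigmoid(logits).squeeze().tolist()
    return dict(zip(labels, probs))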
# Example usage (meaningful only after fine-tuning; the fresh head is random)
text = "You are a stupid idiot"
preds = classify(text)
print(preds)  # expected after fine-tuning: something like ['toxic', 'insult']
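# Batch usage sketch (an assumption, not part of the original app): the
# tokenizer pads a list of strings to a common length, so one model call
# can score many inputs at once.
batch = ["You are a stupid idiot", "Have a lovely day"]
enc = tokenizer(batch, padding=True, truncation=True, max_length=128, return_tensors='pt')
with torch.no_grad():
    batch_logits = model(**enc).logits
batch_preds = torch.sigmoid(batch_logits) > 0.5
for sent, row in zip(batch, batch_preds.tolist()):
    print(sent, '->', [labels[i] for i, p in enumerate(row) if p])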