# Captured from a Hugging Face Space page; status shown at capture time: "Build error".
# Dependencies: transformers, torch.
# `!pip install transformers torch` is IPython/notebook shell magic and is a
# syntax error in a plain .py file. In a notebook, run that line in its own
# cell; for a script or a Hugging Face Space, list both packages in
# requirements.txt instead.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Define the labels and their corresponding indices
labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
label2id = {label: i for i, label in enumerate(labels)}

# Load the pretrained BERT model and tokenizer
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# num_labels is derived from `labels` so the size of the classification head
# cannot drift from the label list.
# NOTE(review): 'bert-base-uncased' is a base checkpoint; transformers attaches
# a newly initialised classification head to it, so predictions are not
# meaningful until the model is fine-tuned or a fine-tuned checkpoint is
# loaded instead.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
)
def preprocess(text, max_length=128):
    """Tokenize *text* into PyTorch tensors ready to feed to the model.

    Args:
        text: The input to encode. The visible caller passes a single
            string; the tokenizer is also documented to accept a list of
            strings, which is encoded as one padded batch.
        max_length: Maximum number of tokens per sequence; longer inputs
            are truncated. Defaults to 128, the value that was previously
            hard-coded.

    Returns:
        A ``(input_ids, attention_mask)`` tuple taken from the tokenizer
        output requested with ``return_tensors='pt'``.
    """
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    return inputs['input_ids'], inputs['attention_mask']
def classify(text, threshold=0.5):
    """Return the labels whose predicted probability exceeds *threshold*.

    Each label is scored independently with a sigmoid (multi-label), so
    zero, one, or several labels may be returned for the same text.

    Args:
        text: A single string, or a list of strings to score as a batch.
        threshold: Probability a label must exceed to be reported.
            Defaults to 0.5, the value that was previously hard-coded.

    Returns:
        For a single string, a list of label names. For a list of strings,
        a list with one list of label names per input text, in input order.
        (Previously a batch was squeezed and iterated row-wise, which
        returned labels based on row position rather than on predictions.)
    """
    input_ids, attention_mask = preprocess(text)
    with torch.no_grad():  # inference only; no gradients needed
        logits = model(input_ids, attention_mask=attention_mask).logits
    # Keep the batch dimension instead of squeezing it away, so every row is
    # handled the same way whether one text or many were passed.
    # assumes logits is (batch, num_labels) — TODO confirm against the model docs
    flags = (torch.sigmoid(logits) > threshold).tolist()
    per_text = [
        [label for label, hit in zip(labels, row) if hit]
        for row in flags
    ]
    # A bare string is a batch of one: unwrap it to keep the original
    # single-text return shape (a flat list of label names).
    if isinstance(text, str):
        return per_text[0]
    return per_text
# Example usage
text = "You are a stupid idiot"
preds = classify(text)
# With a model fine-tuned for toxicity, the expected output would be something
# like ['toxic', 'insult'].
# NOTE(review): the model loaded in this file is 'bert-base-uncased' with a
# fresh classification head, so unless it is fine-tuned elsewhere the labels
# printed here are arbitrary.
print(preds)