# Toxiclassifier / app.py
# Author: dp92 - commit 3630984 ("Update app.py"), 1.29 kB
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load the pre-trained BERT model and tokenizer.
# NOTE(review): "bert-base-uncased" is a plain pre-trained encoder. Its
# checkpoint has no weights for a 6-way classification head, so transformers
# initializes that head randomly (and logs a warning saying so). Until this
# points at a checkpoint fine-tuned on the toxicity labels, the scores printed
# by this script are not meaningful.
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# problem_type="multi_label_classification" records that each of the 6 labels
# is an independent yes/no decision. This matches the per-label sigmoid that
# classify() applies to the logits, and selects a BCE loss if this model is
# ever trained with labels.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=6,
    problem_type="multi_label_classification",
)
# from_pretrained already returns the model in eval mode. The explicit call
# makes it obvious that this script only runs inference (dropout disabled).
model.eval()

# Load the data. The classification loop reads its "comment_text" column and
# the six label columns.
data = pd.read_csv("toxic_comments.csv")
# Define the function to preprocess the text
def preprocess(text):
    """Tokenize ``text`` into BERT model inputs.

    Args:
        text: The string to tokenize. The tokenizer also accepts a list of
            strings, which yields a batch.

    Returns:
        An ``(input_ids, attention_mask)`` tuple of PyTorch tensors with one
        row per input text. Sequences longer than 128 tokens are truncated.
        ``padding=True`` pads a batch to its longest member.
    """
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    return inputs["input_ids"], inputs["attention_mask"]
# Define the function to classify a text input
def classify(text):
    """Score a single comment against every toxicity label.

    Args:
        text: The comment to classify.

    Returns:
        A dict mapping each name in the module-level ``labels`` list to the
        sigmoid probability the model assigns to it. Labels are scored
        independently (multi-label), so the values need not sum to 1.

    Raises:
        ValueError: If the model emits a different number of scores than
            there are names in ``labels``.
    """
    input_ids, attention_mask = preprocess(text)
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits
    # Drop only the batch dimension (size 1 for a single text). A bare
    # squeeze() would also collapse the label axis if the model had one label.
    probs = torch.sigmoid(logits).squeeze(0).tolist()
    # ``labels`` is looked up when this runs, so it only has to be defined
    # before the first call, not before this def.
    if len(probs) != len(labels):
        raise ValueError(
            f"model produced {len(probs)} scores but {len(labels)} labels are defined"
        )
    return dict(zip(labels, probs))
# Define the labels. classify() reads this list when it is called. Its order
# must match the model's output columns, and its length must equal the
# model's num_labels.
labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

# Classify the comments and print the results
for idx, row in data.iterrows():
    text = row["comment_text"]
    # pandas reads empty cells as NaN (a float), which the tokenizer rejects.
    # Skip such rows instead of aborting the whole run.
    if pd.isna(text):
        print(f"Skipping row {idx}: comment_text is missing")
        continue
    # Coerce values pandas may have parsed as numbers back to text.
    text = str(text)
    preds = classify(text)
    print("Comment: ", text)
    print("Predictions: ", preds)
    print("Labels: ", row[labels].to_dict())