!pip install transformers
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
# Load the pretrained BERT model and tokenizer, configured for multi-label classification
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=6,
    problem_type='multi_label_classification'  # use BCE-with-logits loss for one-hot targets
)
# Define the training data and labels
train_texts = [...]  # List of training text inputs
train_labels = [...]  # List of training labels (one-hot encoded)
# Define a function to preprocess a text input
def preprocess(text):
    inputs = tokenizer(text, padding=True, truncation=True, max_length=128, return_tensors='pt')
    return inputs['input_ids'], inputs['attention_mask']
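# Example: for a single string, preprocess returns (1, seq_len) tensors, e.g.
#   input_ids, attention_mask = preprocess("hello world")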
# Define a function to encode the one-hot label lists as float tensors (required by the BCE loss)
def encode_labels(labels):
    return torch.tensor(labels, dtype=torch.float)
# Build the training dataset as a list of dicts, the format Trainer's default collator expects
enc = tokenizer(train_texts, padding='max_length', truncation=True, max_length=128)
train_labels = encode_labels(train_labels)
train_dataset = [{'input_ids': enc['input_ids'][i], 'attention_mask': enc['attention_mask'][i],
                  'labels': train_labels[i]} for i in range(len(train_texts))]
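# For larger corpora, tokenizing lazily inside a torch Dataset avoids holding every
# encoding in memory at once; a minimal alternative sketch (the class name is illustrative):
class ToxicCommentsDataset(torch.utils.data.Dataset):
    def __init__(self, texts, labels):
        self.texts, self.labels = texts, labels
    def __len__(self):
        return len(self.texts)
    def __getitem__(self, idx):
        item = tokenizer(self.texts[idx], padding='max_length', truncation=True, max_length=128)
        return {'input_ids': item['input_ids'],
                'attention_mask': item['attention_mask'],
                'labels': self.labels[idx]}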
# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10
)
# Define the trainer object
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset
)
# Train the model
trainer.train()
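# Optionally persist the fine-tuned weights for later reuse
# (the './toxic-bert' directory name is illustrative):
trainer.save_model('./toxic-bert')
tokenizer.save_pretrained('./toxic-bert')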
# The six label names, in the same order as the one-hot training labels
label_names = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# Define a function to classify a text input, returning a probability per label
def classify(text):
    model.eval()  # disable dropout for inference
    input_ids, attention_mask = preprocess(text)
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits
    preds = torch.sigmoid(logits).squeeze().tolist()
    return {label_names[i]: preds[i] for i in range(len(label_names))}
# Example usage
text = "You are a stupid idiot"
preds = classify(text)
print(preds)  # e.g. {'toxic': 0.98, 'severe_toxic': 0.03, 'obscene': 0.94, 'threat': 0.01, 'insult': 0.88, 'identity_hate': 0.02}
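Since the model emits an independent probability per label, a natural next step is to threshold them into binary flags. The sketch below uses a uniform 0.5 cutoff, which is illustrative only; in practice thresholds are often tuned per label on a validation set.

# Turn per-label probabilities into binary flags (the 0.5 cutoff is illustrative)
def flag_labels(preds, threshold=0.5):
    return [name for name, p in preds.items() if p >= threshold]

print(flag_labels(preds))  # e.g. ['toxic', 'obscene', 'insult']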