!pip install transformers
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
# The six toxicity labels the classifier predicts (used for num_labels and for the output dict)
label_names = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# Load the pretrained BERT model and tokenizer
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(label_names),
    problem_type='multi_label_classification'  # use BCE-with-logits loss for one-hot targets
)
# Define the training data and labels
train_texts = [...] # List of training text inputs
train_labels = [...] # List of training labels (one-hot encoded)
# Define a function to preprocess the text input
def preprocess(text):
    # Pad to a fixed length so examples can be batched by the default collator
    inputs = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    return inputs['input_ids'], inputs['attention_mask']
# Define a function to encode the labels as float tensors (multi-label targets must be floats)
def encode_labels(labels):
    return torch.tensor(labels, dtype=torch.float)
# Wrap the training data in a torch Dataset so the Trainer receives dict-style batches
class ToxicDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = encode_labels(labels)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        input_ids, attention_mask = preprocess(self.texts[idx])
        return {
            'input_ids': input_ids.squeeze(0),
            'attention_mask': attention_mask.squeeze(0),
            'labels': self.labels[idx],
        }

train_dataset = ToxicDataset(train_texts, train_labels)
# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10
)
# Define the trainer object
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset
)
# Train the model
trainer.train()
# Define a function to classify a text input
def classify(text):
    model.eval()  # make sure dropout is disabled for inference
    input_ids, attention_mask = preprocess(text)
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits
    preds = torch.sigmoid(logits).squeeze().tolist()
    return {label_names[i]: preds[i] for i in range(len(label_names))}
# Example usage
text = "You are a stupid idiot"
preds = classify(text)
print(preds)  # Output: {'toxic': 0.98, 'severe_toxic': 0.03, 'obscene': 0.94, 'threat': 0.01, 'insult': 0.88, 'identity_hate': 0.02}