File size: 2,043 Bytes
054e256
 
4c30733
054e256
4c30733
 
 
 
 
 
054e256
 
 
4c30733
 
 
 
 
 
054e256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c30733
 
 
 
230173f
 
4c30733
054e256
 
4c30733
054e256
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Jupyter/IPython shell escape: installs the transformers package when this is
# run in a notebook cell. A plain Python interpreter rejects this line as a
# syntax error, so run it only in a notebook (or install from the shell).
!pip install transformers

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

# Load the pretrained BERT tokenizer and a 6-output sequence-classification
# model on top of it. problem_type pins the multi-label objective (one
# independent sigmoid/BCE term per label) instead of leaving Transformers to
# infer the loss from the dtype of the labels at training time.
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=6,
    problem_type='multi_label_classification',
)

# Training data placeholders -- replace each [...] with real data.
train_texts = [...]  # list of training text strings
train_labels = [...]  # list of multi-hot label rows: one 0/1 entry per label, 6 per row

# Tokenize a single text for the model.
def preprocess(text):
    """Tokenize ``text`` with the module-level BERT tokenizer.

    Returns:
        A ``(input_ids, attention_mask)`` pair of PyTorch tensors (batch
        dimension of 1 for a single string), truncated to at most 128 tokens.
    """
    encoding = tokenizer(
        text,
        return_tensors='pt',
        max_length=128,
        truncation=True,
        padding=True,
    )
    return encoding['input_ids'], encoding['attention_mask']

# Convert label rows into a floating-point tensor.
def encode_labels(labels):
    """Return ``labels`` (e.g. a list of multi-hot rows) as a float32 tensor.

    Float targets are needed because each label is scored with its own
    sigmoid rather than a softmax over classes.
    """
    label_tensor = torch.tensor(labels, dtype=torch.float32)
    return label_tensor

# Build the training set. The whole corpus is tokenized in one call so every
# example is padded to a common length (the longest text, truncated to 128
# tokens -- the same limit preprocess() uses). Each item is a dict keyed by
# the model's forward() argument names, which is the format the Trainer's
# default data collator stacks into batches.
train_labels = encode_labels(train_labels)
if len(train_texts) != len(train_labels):
    # zip() below would otherwise silently drop the unmatched tail.
    raise ValueError(
        f"train_texts has {len(train_texts)} items but train_labels has "
        f"{len(train_labels)} rows"
    )
train_encodings = tokenizer(
    list(train_texts),
    padding=True,
    truncation=True,
    max_length=128,
    return_tensors='pt',
)
train_dataset = [
    {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': label_row}
    for input_ids, attention_mask, label_row in zip(
        train_encodings['input_ids'],
        train_encodings['attention_mask'],
        train_labels,
    )
]

# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10
)

# Define the trainer. No tokenizer is passed, so batches are assembled by the
# default data collator, which stacks the equal-length tensors built above.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

# Train the model
trainer.train()

# Names for the six model outputs, in output-unit order. This order must match
# the column order of the training label rows.
LABELS = ('toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate')


# Define a function to classify a text input
def classify(text):
    """Score one text against every label.

    Args:
        text: the raw input string.

    Returns:
        dict mapping each name in ``LABELS`` to an independent sigmoid
        probability in [0, 1] (multi-label: the scores need not sum to 1).

    Raises:
        ValueError: if the model's number of outputs differs from ``len(LABELS)``.
    """
    input_ids, attention_mask = preprocess(text)
    # Training can leave the model in train mode (dropout active) and the
    # Trainer may have moved it to a GPU: switch to eval mode and put the
    # CPU-side inputs on whatever device the weights live on.
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(
            input_ids.to(device), attention_mask=attention_mask.to(device)
        ).logits
    # One row of logits per input text; index that row explicitly rather than
    # squeeze(), which would collapse a single-output model to a scalar.
    probs = torch.sigmoid(logits[0]).tolist()
    if len(probs) != len(LABELS):
        raise ValueError(
            f"model produced {len(probs)} outputs but {len(LABELS)} label names are defined"
        )
    return dict(zip(LABELS, probs))

# Demo: score one deliberately abusive sentence and print the
# label -> probability mapping returned by classify().
text = "You are a stupid idiot"
preds = classify(text)
print(preds)  # e.g. {'toxic': 0.98, 'severe_toxic': 0.03, 'obscene': 0.94, 'threat': 0.01, 'insult': 0.88, 'identity_hate': 0.02} (actual scores depend on the trained weights)