import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
# Load the data
train_data = ...  # load your training data here
eval_data = ...  # load your evaluation data here
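# A minimal sketch (not from the original) of one way to build these datasets,
# assuming the Hugging Face `datasets` library and a dataset with a 'text' column
# and an integer 'label' column; the dataset name 'imdb' is purely illustrative,
# and the number of distinct labels must match num_labels below:
#
#   from datasets import load_dataset
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#   raw = load_dataset('imdb')
#   def tokenize(batch):
#       return tokenizer(batch['text'], truncation=True, padding='max_length', max_length=128)
#   train_data = raw['train'].map(tokenize, batched=True)
#   eval_data = raw['test'].map(tokenize, batched=True)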
# Define the model architecture; num_labels must match the number of classes in your data
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=8)
# Set up the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_first_step=True,
    logging_steps=50,
    save_total_limit=2,
    save_steps=500,
    evaluation_strategy='steps',  # needed so eval_steps takes effect ('eval_strategy' on recent transformers versions)
    eval_steps=500,
    learning_rate=5e-5,
    seed=42,
)
# Define how evaluation metrics are computed; the Trainer passes an EvalPrediction
# carrying the model logits and the true label ids
def compute_metrics(eval_pred):
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    preds = np.argmax(logits, axis=-1)
    return {'accuracy': (preds == labels).mean()}

# Create the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    compute_metrics=compute_metrics,
)
# Train the model
trainer.train()

# Evaluate the model; trainer.evaluate() returns a dict of metrics
# (eval_loss, eval_accuracy, runtime statistics, ...)
metrics = trainer.evaluate()
print(f"Loss: {metrics['eval_loss']}")
print(f'Metrics: {metrics}')
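
# Usage sketch (an addition, not part of the original script): run inference with
# the fine-tuned model; this assumes the 'bert-base-uncased' tokenizer that matches
# the model checkpoint loaded above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
inputs = tokenizer('Example text to classify', return_tensors='pt')
inputs = {k: v.to(model.device) for k, v in inputs.items()}  # move tensors to the model's device
with torch.no_grad():
    logits = model(**inputs).logits
predicted_class = logits.argmax(dim=-1).item()
print(f'Predicted class id: {predicted_class}')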