import torch
from transformers import Trainer, TrainingArguments
from datasets import load_dataset


def train_model(model):
    """Fine-tune `model` on the dance-video dataset."""
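    # Effective batch size: per_device_train_batch_size (4) *
    # gradient_accumulation_steps (4) = 16 samples per optimizer step
    # on each device.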
    training_args = TrainingArguments(
        output_dir="./checkpoints",
        num_train_epochs=100,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=1e-4,
        fp16=True,  # mixed-precision training; requires a CUDA GPU
        save_steps=500,  # write a checkpoint every 500 optimizer steps
    )
    # load_dataset returns a DatasetDict of splits; Trainer expects a single
    # Dataset, so select the train split explicitly.
    dataset = load_dataset("dance_videos_dataset", split="train")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
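

# Usage sketch. The original snippet never constructs `model`, so the choice
# of checkpoint below is an assumption: AutoModelForVideoClassification and
# "MCG-NJU/videomae-base" are stand-ins for whichever video model the project
# actually fine-tunes; substitute your own.
if __name__ == "__main__":
    from transformers import AutoModelForVideoClassification

    model = AutoModelForVideoClassification.from_pretrained(
        "MCG-NJU/videomae-base"
    )
    train_model(model)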