File size: 569 Bytes
9271111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from transformers import Trainer, TrainingArguments
from datasets import load_dataset

def train_model():
    training_args = TrainingArguments(
        output_dir="./checkpoints",
        num_train_epochs=100,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=1e-4,
        fp16=True,
        save_steps=500,
    )
    
    dataset = load_dataset("dance_videos_dataset")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )
    
    trainer.train()