ChromiumPlutoniumAI commited on
Commit
9271111
·
verified ·
1 Parent(s): 659a25c

Create train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +23 -0
train_model.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import Trainer, TrainingArguments
3
+ from datasets import load_dataset
4
+
5
+ def train_model():
6
+ training_args = TrainingArguments(
7
+ output_dir="./checkpoints",
8
+ num_train_epochs=100,
9
+ per_device_train_batch_size=4,
10
+ gradient_accumulation_steps=4,
11
+ learning_rate=1e-4,
12
+ fp16=True,
13
+ save_steps=500,
14
+ )
15
+
16
+ dataset = load_dataset("dance_videos_dataset")
17
+ trainer = Trainer(
18
+ model=model,
19
+ args=training_args,
20
+ train_dataset=dataset,
21
+ )
22
+
23
+ trainer.train()