Spaces:
Runtime error
Runtime error
import sys | |
sys.path.append(".") | |
from opensora.models.ae.videobase.dataset_videobase import VideoDataset | |
from opensora.models.ae.videobase import ( | |
VQVAEModel, | |
VQVAEConfiguration, | |
VQVAETrainer, | |
) | |
import argparse | |
from typing import Optional | |
from accelerate.utils import set_seed | |
from transformers import HfArgumentParser, TrainingArguments | |
from dataclasses import dataclass, field, asdict | |
@dataclass
class VQVAEArgument:
    """Command-line configurable VQ-VAE hyperparameters plus the dataset path.

    Declared as a dataclass so that ``HfArgumentParser`` can turn each field
    into a command-line flag. Every default is a plain scalar: a trailing
    comma after ``field(...)`` would silently turn the default into a
    one-element tuple.
    """

    embedding_dim: int = field(default=256)
    n_codes: int = field(default=2048)
    n_hiddens: int = field(default=240)
    n_res_layers: int = field(default=4)
    resolution: int = field(default=128)
    sequence_length: int = field(default=16)
    # Comma-separated string; NOTE(review): presumably one downsample factor
    # per axis, parsed by the model configuration - confirm.
    downsample: str = field(default="4,4,4")
    no_pos_embd: bool = field(default=True)
    data_path: Optional[str] = field(default=None, metadata={"help": "data path"})
@dataclass
class VQVAETrainingArgument(TrainingArguments):
    """``TrainingArguments`` variant whose ``remove_unused_columns`` defaults to False.

    The ``@dataclass`` decorator is required: without it the ``field(...)``
    override below is never registered as a dataclass field, and the
    inherited constructor keeps the parent's default.
    """

    # NOTE(review): presumably disabled so the trainer does not strip the
    # video dataset's columns before they reach the model - confirm.
    remove_unused_columns: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Remove columns not required by the model when using an nlp.Dataset."
        },
    )
def train(args, vqvae_args: VQVAEArgument, training_args: VQVAETrainingArgument):
    """Assemble the VQ-VAE model and video dataset, then run training.

    Args:
        args: Object exposing ``data_path`` and ``sequence_length``; used
            when constructing the dataset.
        vqvae_args: Hyperparameters copied field-by-field into
            ``VQVAEConfiguration``.
        training_args: Handed unchanged to ``VQVAETrainer``.
    """
    # Each hyperparameter is forwarded to the configuration under its own name.
    config_fields = (
        "embedding_dim",
        "n_codes",
        "n_hiddens",
        "n_res_layers",
        "resolution",
        "sequence_length",
        "downsample",
        "no_pos_embd",
    )
    config = VQVAEConfiguration(
        **{name: getattr(vqvae_args, name) for name in config_fields}
    )

    # Model built from the configuration above.
    vqvae = VQVAEModel(config)

    # Dataset location and clip length come from ``args``; the resolution
    # comes from the configuration object.
    video_data = VideoDataset(
        args.data_path,
        sequence_length=args.sequence_length,
        resolution=config.resolution,
    )

    # Hand everything to the trainer and start the run.
    VQVAETrainer(vqvae, training_args, train_dataset=video_data).train()
if __name__ == "__main__":
    # Parse the command line into the two argument dataclasses.
    arg_parser = HfArgumentParser((VQVAEArgument, VQVAETrainingArgument))
    vqvae_args, training_args = arg_parser.parse_args_into_dataclasses()

    # Flatten both argument objects into one namespace for attribute access.
    args = argparse.Namespace(
        **vars(vqvae_args),
        **vars(training_args),
    )

    # Seed before the model and dataset are created inside train().
    set_seed(args.seed)
    train(args, vqvae_args, training_args)