Spaces:
Runtime error
Runtime error
import sys | |
sys.path.append(".") | |
from opensora.models.ae.videobase.dataset_videobase import VideoDataset | |
from opensora.models.ae.videobase import ( | |
VQVAEModel, | |
VQVAEConfiguration, | |
VQVAETrainer, | |
) | |
import argparse | |
from typing import Optional | |
from accelerate.utils import set_seed | |
from transformers import HfArgumentParser, TrainingArguments | |
from dataclasses import dataclass, field, asdict | |
class VQVAEArgument: | |
embedding_dim: int = field(default=256), | |
n_codes: int = field(default=2048), | |
n_hiddens: int = field(default=240), | |
n_res_layers: int = field(default=4), | |
resolution: int = field(default=128), | |
sequence_length: int = field(default=16), | |
downsample: str = field(default="4,4,4"), | |
no_pos_embd: bool = True, | |
data_path: str = field(default=None, metadata={"help": "data path"}) | |
class VQVAETrainingArgument(TrainingArguments): | |
remove_unused_columns: Optional[bool] = field( | |
default=False, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."} | |
) | |
def train(args, vqvae_args, training_args): | |
# Load Config | |
config = VQVAEConfiguration(**asdict(vqvae_args)) | |
# Load Model | |
model = VQVAEModel(config) | |
# Load Dataset | |
dataset = VideoDataset(args.data_path, sequence_length=args.sequence_length, resolution=config.resolution) | |
# Load Trainer | |
trainer = VQVAETrainer(model, training_args, train_dataset=dataset) | |
trainer.train() | |
if __name__ == "__main__": | |
parser = HfArgumentParser((VQVAEArgument, VQVAETrainingArgument)) | |
vqvae_args, training_args = parser.parse_args_into_dataclasses() | |
args = argparse.Namespace(**vars(vqvae_args), **vars(training_args)) | |
set_seed(args.seed) | |
train(args, vqvae_args, training_args) | |