# Page residue from the hosting site (not part of the module): "Spaces: Runtime error"
from ..trainer_videobase import VideoBaseTrainer | |
import torch.nn.functional as F | |
from typing import Optional | |
import os | |
import torch | |
from transformers.utils import WEIGHTS_NAME | |
import json | |
class VQVAETrainer(VideoBaseTrainer):
    """Trainer whose loss is a scaled MSE reconstruction term plus the
    codebook's commitment loss.

    The model passed to :meth:`compute_loss` must expose ``encoder``,
    ``pre_vq_conv``, ``codebook``, ``post_vq_conv`` and ``decoder``; calling
    ``codebook`` must return a mapping with ``"embeddings"`` and
    ``"commitment_loss"`` entries.
    """

    # Divisor applied to the input video before it is encoded.
    # NOTE(review): presumably rescales clips from [-1, 1] to [-0.5, 0.5] —
    # confirm against the data pipeline.
    INPUT_DIVISOR = 2

    # Divisor applied to the MSE reconstruction term.
    # NOTE(review): looks like a data-variance normaliser — confirm.
    RECON_LOSS_DIVISOR = 0.06

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the VQ-VAE training loss for one batch.

        Args:
            model: The VQ-VAE, either bare or wrapped in an object that
                exposes the real network as ``.module``.
            inputs: Mapping that must contain the batch under ``"video"``.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                the loss alone; ``outputs`` is a dict holding ``"loss"`` and
                the reconstruction ``"x_recon"``.
            num_items_in_batch: Accepted for compatibility with callers that
                pass it as a keyword; unused because the MSE term is already
                mean-reduced.

        Returns:
            The scalar loss tensor, or ``(loss, outputs)`` when
            ``return_outputs`` is true.

        Raises:
            ValueError: If ``inputs`` has no ``"video"`` entry.
        """
        # Unwrap only when a wrapper is present, so an unwrapped model
        # (no ``.module`` attribute) works as well.
        # NOTE(review): calling the sub-modules directly bypasses the
        # wrapper's own forward — confirm gradient synchronisation still
        # behaves as intended in distributed runs.
        net = getattr(model, "module", model)

        video = inputs.get("video")
        if video is None:
            # Fail with a clear message instead of an opaque
            # "NoneType / int" TypeError on the next line.
            raise ValueError('`inputs` must provide a "video" tensor')
        x = video / self.INPUT_DIVISOR

        # encode -> quantise -> decode
        z = net.pre_vq_conv(net.encoder(x))
        vq_output = net.codebook(z)
        x_recon = net.decoder(net.post_vq_conv(vq_output["embeddings"]))

        recon_loss = F.mse_loss(x_recon, x) / self.RECON_LOSS_DIVISOR
        loss = recon_loss + vq_output["commitment_loss"]

        if return_outputs:
            # Dict form with a leading "loss" key, so callers that drop the
            # loss entry are left with just the reconstruction.
            return loss, {"loss": loss, "x_recon": x_recon}
        return loss