File size: 706 Bytes
a220803
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from ..trainer_videobase import VideoBaseTrainer
import torch.nn.functional as F
from typing import Optional
import os
import torch
from transformers.utils import WEIGHTS_NAME
import json

class VQVAETrainer(VideoBaseTrainer):
    """Trainer whose loss is a scaled MSE reconstruction loss plus the
    codebook's commitment loss."""

    # Divisor applied to the raw "video" batch before it is encoded.
    # NOTE(review): the expected value range of the input is not visible
    # here; confirm against the dataset/transform code.
    INPUT_DIVISOR = 2
    # Divisor applied to the MSE reconstruction loss.
    # NOTE(review): presumably an estimate of the data variance; confirm.
    RECON_LOSS_DIVISOR = 0.06

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the training loss for one batch.

        Args:
            model: Object exposing ``encoder``, ``pre_vq_conv``, ``codebook``,
                ``post_vq_conv`` and ``decoder``, either directly or through a
                ``.module`` attribute (i.e. wrapped by some container).
            inputs: Mapping holding the batch under the ``"video"`` key.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is
            true, where ``outputs`` is a dict with ``recon_loss``,
            ``commitment_loss`` and ``x_recon``.

        Raises:
            KeyError: If ``inputs`` has no ``"video"`` entry.
        """
        # Unwrap only when a wrapper is actually present, so the same code
        # works for wrapped and bare models.
        # NOTE(review): calling submodules directly skips the wrapper's own
        # forward(); confirm this is intended for the wrapper in use.
        model = getattr(model, "module", model)

        x = inputs.get("video")
        if x is None:
            # Fail with a clear message instead of a TypeError from `None / 2`.
            raise KeyError('inputs must contain a "video" entry')
        x = x / self.INPUT_DIVISOR

        # encoder -> pre-quantisation conv -> codebook -> post conv -> decoder
        z = model.pre_vq_conv(model.encoder(x))
        vq_output = model.codebook(z)
        x_recon = model.decoder(model.post_vq_conv(vq_output["embeddings"]))

        recon_loss = F.mse_loss(x_recon, x) / self.RECON_LOSS_DIVISOR
        commitment_loss = vq_output['commitment_loss']
        loss = recon_loss + commitment_loss

        if return_outputs:
            # Honour the flag so callers that unpack `(loss, outputs)` work.
            outputs = {
                "recon_loss": recon_loss,
                "commitment_loss": commitment_loss,
                "x_recon": x_recon,
            }
            return loss, outputs
        return loss