ChromiumPlutoniumAI commited on
Commit
53abcc6
·
verified ·
1 Parent(s): 38bdd50

Create trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +16 -0
trainer.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Text2VideoTrainer:
2
+ def __init__(self, model, optimizer, device):
3
+ self.model = model
4
+ self.optimizer = optimizer
5
+ self.device = device
6
+
7
+ def train_step(self, text_batch, video_batch):
8
+ self.optimizer.zero_grad()
9
+
10
+ generated_video = self.model(text_batch)
11
+ loss = F.mse_loss(generated_video, video_batch)
12
+
13
+ loss.backward()
14
+ self.optimizer.step()
15
+
16
+ return loss.item()