import torch

from config.model_config import CONFIG
from models.text2video_model import Text2VideoModel
from training.trainer import Text2VideoTrainer


def main():
    # Train on the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build the text-to-video model from the shared configuration.
    model = Text2VideoModel(
        vocab_size=CONFIG['vocab_size'],
        embed_dim=CONFIG['embed_dim'],
        latent_dim=CONFIG['latent_dim'],
        num_frames=CONFIG['num_frames'],
        frame_size=CONFIG['frame_size']
    ).to(device)

    # Adam optimizer and the trainer that wraps the per-batch update step.
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
    trainer = Text2VideoTrainer(model, optimizer, device)
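
    # Training loop sketch: no data pipeline is defined in this script, so the
    # `dataloader` and the `num_epochs` config key below are hypothetical
    # placeholders; adapt them to the project's actual dataset.
    #
    # for epoch in range(CONFIG['num_epochs']):
    #     for text_batch, video_batch in dataloader:
    #         loss = trainer.train_step(text_batch.to(device), video_batch.to(device))
    #     print(f'epoch {epoch}: loss {loss:.4f}')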
|

if __name__ == '__main__':
    main()


# training/trainer.py (the module imported above)
import torch.nn.functional as F


class Text2VideoTrainer:
    def __init__(self, model, optimizer, device):
        self.model = model
        self.optimizer = optimizer
        self.device = device

    def train_step(self, text_batch, video_batch):
        """Run a single optimization step and return the scalar loss."""
        # Clear gradients left over from the previous step.
        self.optimizer.zero_grad()

        # Generate a video from the text batch and score it against the target
        # video with a per-pixel MSE reconstruction loss.
        generated_video = self.model(text_batch)
        loss = F.mse_loss(generated_video, video_batch)

        # Backpropagate and update the model parameters.
        loss.backward()
        self.optimizer.step()

        return loss.item()
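
# Example usage (a minimal sketch; the setup mirrors main() above and the batch
# tensors are assumed to come from the project's data pipeline):
#
#   trainer = Text2VideoTrainer(model, optimizer, device)
#   loss = trainer.train_step(text_batch, video_batch)  # returns a Python float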