import torch
import torch.nn.functional as F

from models.text2video_model import Text2VideoModel
from training.trainer import Text2VideoTrainer
from config.model_config import CONFIG


def main() -> None:
    """Build the model, optimizer and trainer described by ``CONFIG``.

    Uses CUDA when ``torch.cuda.is_available()`` reports a device and falls
    back to CPU otherwise. The model is moved to that device before its
    parameters are handed to the Adam optimizer.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = Text2VideoModel(
        vocab_size=CONFIG['vocab_size'],
        embed_dim=CONFIG['embed_dim'],
        latent_dim=CONFIG['latent_dim'],
        num_frames=CONFIG['num_frames'],
        frame_size=CONFIG['frame_size'],
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
    trainer = Text2VideoTrainer(model, optimizer, device)

    # Add your data loading and training loop here


if __name__ == '__main__':
    main()


# NOTE(review): this class rebinds the name imported from ``training.trainer``
# above. Because it is defined after the ``__main__`` guard, ``main()`` still
# uses the imported class when the file is run as a script. This looks like
# two modules pasted together -- confirm which definition is intended.
class Text2VideoTrainer:
    """Bundle a model, its optimizer and a device for single-step training."""

    def __init__(self, model, optimizer, device):
        """Store the collaborators; nothing is moved or modified here."""
        self.model = model
        self.optimizer = optimizer
        self.device = device

    def train_step(self, text_batch, video_batch) -> float:
        """Run one optimization step and return the scalar loss.

        Clears accumulated gradients, runs the model on ``text_batch``,
        computes the mean-squared error between the generated video and
        ``video_batch``, backpropagates and applies the optimizer update.

        The batches are used as given: this method does not move them to
        ``self.device``, so the caller must supply them on the model's
        device.

        Returns:
            The loss value as a Python number (``loss.item()``).
        """
        self.optimizer.zero_grad()
        generated_video = self.model(text_batch)
        loss = F.mse_loss(generated_video, video_batch)
        loss.backward()
        self.optimizer.step()
        return loss.item()