import torch
import torch.nn as nn


class TextEncoder(nn.Module):
    """Encodes a batch of token ids into contextual text features."""

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True so inputs/outputs are (batch, seq, embed).
        # embed_dim must be divisible by nhead; hidden_dim is used as the
        # feed-forward width inside each transformer layer.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                embed_dim, nhead=8, dim_feedforward=hidden_dim, batch_first=True
            ),
            num_layers=6,
        )

    def forward(self, text):
        x = self.embedding(text)    # (batch, seq) -> (batch, seq, embed)
        return self.transformer(x)  # (batch, seq, embed)


class VideoGenerator(nn.Module):
    """Decodes a latent code into an RGB video via stacked ConvTranspose3d blocks."""

    def __init__(self, latent_dim, num_frames, frame_size):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_frames = num_frames
        self.frame_size = frame_size
        # Each stride-2 transposed conv doubles depth/height/width, so a
        # (latent_dim, 1, 1, 1) input always yields a (3, 16, 16, 16) output.
        # num_frames and frame_size are stored but not enforced by these layers.
        self.generator = nn.Sequential(
            nn.ConvTranspose3d(latent_dim, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(512),
            nn.ReLU(),
            nn.ConvTranspose3d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.ConvTranspose3d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, z):
        return self.generator(z)  # (batch, 3, frames, height, width)


class Text2VideoModel(nn.Module):
    """Maps text to a latent code, then decodes that code into a video."""

    def __init__(self, vocab_size, embed_dim, latent_dim, num_frames, frame_size):
        super().__init__()
        self.text_encoder = TextEncoder(vocab_size, embed_dim, hidden_dim=512)
        self.video_generator = VideoGenerator(latent_dim, num_frames, frame_size)
        # Map the pooled text feature to a single latent code. The generator's
        # transposed convolutions expand the temporal axis themselves, so the
        # mapper must produce latent_dim features (not latent_dim * num_frames,
        # which would silently multiply the batch dimension in the reshape below).
        self.latent_mapper = nn.Linear(embed_dim, latent_dim)

    def forward(self, text):
        text_features = self.text_encoder(text)                        # (batch, seq, embed)
        latent_vector = self.latent_mapper(text_features.mean(dim=1))  # (batch, latent_dim)
        latent_video = latent_vector.view(-1, self.video_generator.latent_dim, 1, 1, 1)
        generated_video = self.video_generator(latent_video)           # (batch, 3, T, H, W)
        return generated_video
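

# --- Usage sketch (illustrative, not part of the original module) -----------
# A minimal smoke test of the pipeline above. All hyperparameters here
# (vocab_size=1000, embed_dim=64, latent_dim=100, sequence length 12) are
# arbitrary assumptions chosen so the shapes line up: embed_dim must be
# divisible by nhead=8, and the four stride-2 layers fix the output at
# 16 frames of 16x16 regardless of the num_frames / frame_size arguments.
if __name__ == "__main__":
    model = Text2VideoModel(
        vocab_size=1000, embed_dim=64, latent_dim=100,
        num_frames=16, frame_size=16,
    )
    model.eval()  # BatchNorm3d uses running stats, so batch size 1 is fine

    tokens = torch.randint(0, 1000, (1, 12))  # one sequence of 12 token ids
    with torch.no_grad():
        video = model(tokens)

    print(video.shape)  # torch.Size([1, 3, 16, 16, 16]) -> (B, C, T, H, W)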