import torch
import torch.nn as nn


class TextEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                embed_dim,
                nhead=8,
                dim_feedforward=hidden_dim,  # width of each layer's feed-forward block
                batch_first=True,  # inputs are (batch, seq, embed), not the (seq, batch, embed) default
            ),
            num_layers=6,
        )

    def forward(self, text):
        # text: (batch, seq) token ids -> contextual features of shape (batch, seq, embed_dim)
        x = self.embedding(text)
        return self.transformer(x)


class VideoGenerator(nn.Module):
    def __init__(self, latent_dim, num_frames, frame_size):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_frames = num_frames
        self.frame_size = frame_size

        # Upsample spatially only: stride 2 in H and W, stride 1 in time, so the
        # temporal dimension stays at num_frames. Four stride-2 stages take the
        # 1x1 spatial input to 16x16, so this architecture assumes frame_size == 16.
        self.generator = nn.Sequential(
            nn.ConvTranspose3d(latent_dim, 512, kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(512),
            nn.ReLU(),
            nn.ConvTranspose3d(512, 256, kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.ConvTranspose3d(256, 128, kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 3, kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1)),
            nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, z):
        # z: (batch, latent_dim, num_frames, 1, 1) -> video of shape (batch, 3, num_frames, 16, 16)
        return self.generator(z)


class Text2VideoModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, latent_dim, num_frames, frame_size):
        super().__init__()
        self.text_encoder = TextEncoder(vocab_size, embed_dim, hidden_dim=512)
        self.video_generator = VideoGenerator(latent_dim, num_frames, frame_size)
        # One latent vector of size latent_dim per output frame.
        self.latent_mapper = nn.Linear(embed_dim, latent_dim * num_frames)

    def forward(self, text):
        # (batch, seq, embed_dim) contextual token features
        text_features = self.text_encoder(text)
        # Mean-pool over the sequence, then expand to per-frame latents.
        latent_vector = self.latent_mapper(text_features.mean(dim=1))
        # Reshape to (batch, latent_dim, num_frames, 1, 1): channels first,
        # one 1x1 spatial latent per frame for the 3D generator.
        latent_video = latent_vector.view(
            -1,
            self.video_generator.latent_dim,
            self.video_generator.num_frames,
            1,
            1,
        )
        # (batch, 3, num_frames, 16, 16)
        generated_video = self.video_generator(latent_video)
        return generated_video
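

# A minimal smoke test sketching how the model wires together. The vocab size,
# embedding width, latent size, batch size, and sequence length below are
# illustrative assumptions, not values from the source; frame_size must be 16
# to match the generator's four stride-2 spatial stages.
if __name__ == "__main__":
    model = Text2VideoModel(
        vocab_size=10000,
        embed_dim=256,  # must be divisible by nhead=8
        latent_dim=64,
        num_frames=8,
        frame_size=16,
    )
    tokens = torch.randint(0, 10000, (2, 12))  # (batch=2, seq_len=12) token ids
    video = model(tokens)
    print(video.shape)  # torch.Size([2, 3, 8, 16, 16])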