import torch
import torch.nn as nn


class TextEncoder(nn.Module):
    """Encodes a batch of token ids into contextual text features."""

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True so inputs/outputs are (batch, seq_len, embed_dim);
        # hidden_dim sizes the feed-forward sublayer of each encoder layer.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                embed_dim, nhead=8, dim_feedforward=hidden_dim, batch_first=True
            ),
            num_layers=6,
        )

    def forward(self, text):
        x = self.embedding(text)    # (batch, seq_len) -> (batch, seq_len, embed_dim)
        return self.transformer(x)  # (batch, seq_len, embed_dim)


class VideoGenerator(nn.Module):
    """Upsamples a latent vector into an RGB video volume.

    Four stride-2 transposed 3D convolutions grow a 1x1x1 latent to a
    16x16x16 output, i.e. 16 frames of 16x16 pixels. num_frames and
    frame_size are stored for reference, but this layer stack fixes the
    output size.
    """

    def __init__(self, latent_dim, num_frames, frame_size):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_frames = num_frames
        self.frame_size = frame_size
        self.generator = nn.Sequential(
            # Each block doubles depth (time), height, and width: 1 -> 2 -> 4 -> 8 -> 16.
            nn.ConvTranspose3d(latent_dim, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(512),
            nn.ReLU(),
            nn.ConvTranspose3d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.ConvTranspose3d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, z):
        # z: (batch, latent_dim, 1, 1, 1) -> (batch, 3, 16, 16, 16)
        return self.generator(z)


class Text2VideoModel(nn.Module):
    """Maps a text prompt to a short generated video clip."""

    def __init__(self, vocab_size, embed_dim, latent_dim, num_frames, frame_size):
        super().__init__()
        self.text_encoder = TextEncoder(vocab_size, embed_dim, hidden_dim=512)
        self.video_generator = VideoGenerator(latent_dim, num_frames, frame_size)
        # One latent per clip: the generator's transposed 3D convolutions
        # expand the temporal axis themselves to produce all frames.
        self.latent_mapper = nn.Linear(embed_dim, latent_dim)

    def forward(self, text):
        text_features = self.text_encoder(text)     # (batch, seq_len, embed_dim)
        pooled = text_features.mean(dim=1)          # mean-pool tokens: (batch, embed_dim)
        latent_vector = self.latent_mapper(pooled)  # (batch, latent_dim)
        latent_video = latent_vector.view(-1, self.video_generator.latent_dim, 1, 1, 1)
        generated_video = self.video_generator(latent_video)  # (batch, 3, 16, 16, 16)
        return generated_video
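

# ---------------------------------------------------------------------------
# Usage sketch (not part of the model definition): a minimal smoke test,
# assuming an illustrative vocabulary of 1000 tokens and a batch of two
# 16-token prompts. All sizes here are placeholder assumptions, not values
# prescribed by the model above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Text2VideoModel(
        vocab_size=1000,  # assumed tokenizer vocabulary size
        embed_dim=256,    # must be divisible by nhead=8
        latent_dim=64,
        num_frames=16,    # informational; see the VideoGenerator docstring
        frame_size=16,
    )
    tokens = torch.randint(0, 1000, (2, 16))  # dummy token ids, (batch, seq_len)
    with torch.no_grad():
        video = model(tokens)
    # Four stride-2 ConvTranspose3d layers upsample 1x1x1 -> 16x16x16.
    print(video.shape)  # expected: torch.Size([2, 3, 16, 16, 16])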