ChromiumPlutoniumAI commited on
Commit
38bdd50
·
verified ·
1 Parent(s): 0150833

Create text2video_model.py

Browse files
Files changed (1) hide show
  1. text2video_model.py +53 -0
text2video_model.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class TextEncoder(nn.Module):
6
+ def __init__(self, vocab_size, embed_dim, hidden_dim):
7
+ super().__init__()
8
+ self.embedding = nn.Embedding(vocab_size, embed_dim)
9
+ self.transformer = nn.TransformerEncoder(
10
+ nn.TransformerEncoderLayer(embed_dim, nhead=8),
11
+ num_layers=6
12
+ )
13
+
14
+ def forward(self, text):
15
+ x = self.embedding(text)
16
+ return self.transformer(x)
17
+
18
+ class VideoGenerator(nn.Module):
19
+ def __init__(self, latent_dim, num_frames, frame_size):
20
+ super().__init__()
21
+ self.latent_dim = latent_dim
22
+ self.num_frames = num_frames
23
+
24
+ self.generator = nn.Sequential(
25
+ nn.ConvTranspose3d(latent_dim, 512, kernel_size=4, stride=2, padding=1),
26
+ nn.BatchNorm3d(512),
27
+ nn.ReLU(),
28
+ nn.ConvTranspose3d(512, 256, kernel_size=4, stride=2, padding=1),
29
+ nn.BatchNorm3d(256),
30
+ nn.ReLU(),
31
+ nn.ConvTranspose3d(256, 128, kernel_size=4, stride=2, padding=1),
32
+ nn.BatchNorm3d(128),
33
+ nn.ReLU(),
34
+ nn.ConvTranspose3d(128, 3, kernel_size=4, stride=2, padding=1),
35
+ nn.Tanh()
36
+ )
37
+
38
+ def forward(self, z):
39
+ return self.generator(z)
40
+
41
+ class Text2VideoModel(nn.Module):
42
+ def __init__(self, vocab_size, embed_dim, latent_dim, num_frames, frame_size):
43
+ super().__init__()
44
+ self.text_encoder = TextEncoder(vocab_size, embed_dim, hidden_dim=512)
45
+ self.video_generator = VideoGenerator(latent_dim, num_frames, frame_size)
46
+ self.latent_mapper = nn.Linear(embed_dim, latent_dim * num_frames)
47
+
48
+ def forward(self, text):
49
+ text_features = self.text_encoder(text)
50
+ latent_vector = self.latent_mapper(text_features.mean(dim=1))
51
+ latent_video = latent_vector.view(-1, self.video_generator.latent_dim, 1, 1, 1)
52
+ generated_video = self.video_generator(latent_video)
53
+ return generated_video