lym0302 committed on
Commit
9423df5
·
1 Parent(s): 0321bb5
third_party/MMAudio/mmaudio/ext/autoencoder/vae.py CHANGED
@@ -75,11 +75,16 @@ class VAE(nn.Module):
75
  super().__init__()
76
 
77
  if data_dim == 80:
78
- self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
79
- self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
 
 
 
80
  elif data_dim == 128:
81
- self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
82
- self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
 
 
83
 
84
  self.data_mean = self.data_mean.view(1, -1, 1)
85
  self.data_std = self.data_std.view(1, -1, 1)
 
75
  super().__init__()
76
 
77
  if data_dim == 80:
78
+ # self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
79
+ # self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
80
+ self.register_buffer("data_mean", torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
81
+ self.register_buffer("data_std", torch.tensor(DATA_STD_80D, dtype=torch.float32))
82
+
83
  elif data_dim == 128:
84
+ # self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
85
+ # self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
86
+ self.register_buffer("data_mean", torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
87
+ self.register_buffer("data_std", torch.tensor(DATA_STD_128D, dtype=torch.float32))
88
 
89
  self.data_mean = self.data_mean.view(1, -1, 1)
90
  self.data_std = self.data_std.view(1, -1, 1)
third_party/MMAudio/mmaudio/model/embeddings.py CHANGED
@@ -21,10 +21,17 @@ class TimestepEmbedder(nn.Module):
21
  assert dim % 2 == 0, 'dim must be even.'
22
 
23
  with torch.autocast('cuda', enabled=False):
24
- self.freqs = nn.Buffer(
 
 
 
 
 
25
  1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
26
- frequency_embedding_size)),
27
- persistent=False)
 
 
28
  freq_scale = 10000 / max_period
29
  self.freqs = freq_scale * self.freqs
30
 
 
21
  assert dim % 2 == 0, 'dim must be even.'
22
 
23
  with torch.autocast('cuda', enabled=False):
24
+ # self.freqs = nn.Buffer(
25
+ # 1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
26
+ # frequency_embedding_size)),
27
+ # persistent=False)
28
+ self.register_buffer(
29
+ "freqs",
30
  1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
31
+ frequency_embedding_size)),
32
+ persistent=False
33
+ )
34
+
35
  freq_scale = 10000 / max_period
36
  self.freqs = freq_scale * self.freqs
37
 
third_party/MMAudio/mmaudio/model/networks.py CHANGED
@@ -166,8 +166,11 @@ class MMAudio(nn.Module):
166
  self._clip_seq_len,
167
  device=self.device)
168
 
169
- self.latent_rot = nn.Buffer(latent_rot, persistent=False)
170
- self.clip_rot = nn.Buffer(clip_rot, persistent=False)
 
 
 
171
 
172
  def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
173
  self._latent_seq_len = latent_seq_len
 
166
  self._clip_seq_len,
167
  device=self.device)
168
 
169
+ # self.latent_rot = nn.Buffer(latent_rot, persistent=False)
170
+ # self.clip_rot = nn.Buffer(clip_rot, persistent=False)
171
+ self.register_buffer("latent_rot", latent_rot, persistent=False)
172
+ self.register_buffer("clip_rot", clip_rot, persistent=False)
173
+
174
 
175
  def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
176
  self._latent_seq_len = latent_seq_len