Spaces:
Runtime error
Runtime error
Rex Cheng
commited on
Commit
β’
164c335
1
Parent(s):
627e0b8
speed up inference
Browse files- app.py +2 -1
- mmaudio/eval_utils.py +20 -17
- mmaudio/ext/autoencoder/autoencoder.py +5 -1
- mmaudio/model/utils/features_utils.py +7 -5
app.py
CHANGED
@@ -48,7 +48,8 @@ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
|
|
48 |
synchformer_ckpt=model.synchformer_ckpt,
|
49 |
enable_conditions=True,
|
50 |
mode=model.mode,
|
51 |
-
bigvgan_vocoder_ckpt=model.bigvgan_16k_path
|
|
|
52 |
feature_utils = feature_utils.to(device, dtype).eval()
|
53 |
|
54 |
return net, feature_utils, seq_cfg
|
|
|
48 |
synchformer_ckpt=model.synchformer_ckpt,
|
49 |
enable_conditions=True,
|
50 |
mode=model.mode,
|
51 |
+
bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
|
52 |
+
need_vae_encoder=False)
|
53 |
feature_utils = feature_utils.to(device, dtype).eval()
|
54 |
|
55 |
return net, feature_utils, seq_cfg
|
mmaudio/eval_utils.py
CHANGED
@@ -76,29 +76,37 @@ all_model_cfg: dict[str, ModelConfig] = {
|
|
76 |
}
|
77 |
|
78 |
|
79 |
-
def generate(
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
89 |
device = feature_utils.device
|
90 |
dtype = feature_utils.dtype
|
91 |
|
92 |
bs = len(text)
|
93 |
if clip_video is not None:
|
94 |
clip_video = clip_video.to(device, dtype, non_blocking=True)
|
95 |
-
clip_features = feature_utils.encode_video_with_clip(clip_video,
|
|
|
|
|
96 |
else:
|
97 |
clip_features = net.get_empty_clip_sequence(bs)
|
98 |
|
99 |
if sync_video is not None:
|
100 |
sync_video = sync_video.to(device, dtype, non_blocking=True)
|
101 |
-
sync_features = feature_utils.encode_video_with_sync(sync_video,
|
|
|
|
|
102 |
else:
|
103 |
sync_features = net.get_empty_sync_sequence(bs)
|
104 |
|
@@ -185,14 +193,9 @@ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, tor
|
|
185 |
data_chunk = reader.pop_chunks()
|
186 |
clip_chunk = data_chunk[0]
|
187 |
sync_chunk = data_chunk[1]
|
188 |
-
print('clip', clip_chunk.shape, clip_chunk.dtype, clip_chunk.max())
|
189 |
-
print('sync', sync_chunk.shape, sync_chunk.dtype, sync_chunk.max())
|
190 |
assert clip_chunk is not None
|
191 |
assert sync_chunk is not None
|
192 |
|
193 |
-
for i in range(reader.num_out_streams):
|
194 |
-
print(reader.get_out_stream_info(i))
|
195 |
-
|
196 |
clip_frames = clip_transform(clip_chunk)
|
197 |
sync_frames = sync_transform(sync_chunk)
|
198 |
|
|
|
76 |
}
|
77 |
|
78 |
|
79 |
+
def generate(
|
80 |
+
clip_video: Optional[torch.Tensor],
|
81 |
+
sync_video: Optional[torch.Tensor],
|
82 |
+
text: Optional[list[str]],
|
83 |
+
*,
|
84 |
+
negative_text: Optional[list[str]] = None,
|
85 |
+
feature_utils: FeaturesUtils,
|
86 |
+
net: MMAudio,
|
87 |
+
fm: FlowMatching,
|
88 |
+
rng: torch.Generator,
|
89 |
+
cfg_strength: float,
|
90 |
+
clip_batch_size_multiplier: int = 40,
|
91 |
+
sync_batch_size_multiplier: int = 40,
|
92 |
+
) -> torch.Tensor:
|
93 |
device = feature_utils.device
|
94 |
dtype = feature_utils.dtype
|
95 |
|
96 |
bs = len(text)
|
97 |
if clip_video is not None:
|
98 |
clip_video = clip_video.to(device, dtype, non_blocking=True)
|
99 |
+
clip_features = feature_utils.encode_video_with_clip(clip_video,
|
100 |
+
batch_size=bs *
|
101 |
+
clip_batch_size_multiplier)
|
102 |
else:
|
103 |
clip_features = net.get_empty_clip_sequence(bs)
|
104 |
|
105 |
if sync_video is not None:
|
106 |
sync_video = sync_video.to(device, dtype, non_blocking=True)
|
107 |
+
sync_features = feature_utils.encode_video_with_sync(sync_video,
|
108 |
+
batch_size=bs *
|
109 |
+
sync_batch_size_multiplier)
|
110 |
else:
|
111 |
sync_features = net.get_empty_sync_sequence(bs)
|
112 |
|
|
|
193 |
data_chunk = reader.pop_chunks()
|
194 |
clip_chunk = data_chunk[0]
|
195 |
sync_chunk = data_chunk[1]
|
|
|
|
|
196 |
assert clip_chunk is not None
|
197 |
assert sync_chunk is not None
|
198 |
|
|
|
|
|
|
|
199 |
clip_frames = clip_transform(clip_chunk)
|
200 |
sync_frames = sync_transform(sync_chunk)
|
201 |
|
mmaudio/ext/autoencoder/autoencoder.py
CHANGED
@@ -15,7 +15,8 @@ class AutoEncoderModule(nn.Module):
|
|
15 |
*,
|
16 |
vae_ckpt_path,
|
17 |
vocoder_ckpt_path: Optional[str] = None,
|
18 |
-
mode: Literal['16k', '44k']
|
|
|
19 |
super().__init__()
|
20 |
self.vae: VAE = get_my_vae(mode).eval()
|
21 |
vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
|
@@ -35,6 +36,9 @@ class AutoEncoderModule(nn.Module):
|
|
35 |
for param in self.parameters():
|
36 |
param.requires_grad = False
|
37 |
|
|
|
|
|
|
|
38 |
@torch.inference_mode()
|
39 |
def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
|
40 |
return self.vae.encode(x)
|
|
|
15 |
*,
|
16 |
vae_ckpt_path,
|
17 |
vocoder_ckpt_path: Optional[str] = None,
|
18 |
+
mode: Literal['16k', '44k'],
|
19 |
+
need_vae_encoder: bool = True):
|
20 |
super().__init__()
|
21 |
self.vae: VAE = get_my_vae(mode).eval()
|
22 |
vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
|
|
|
36 |
for param in self.parameters():
|
37 |
param.requires_grad = False
|
38 |
|
39 |
+
if not need_vae_encoder:
|
40 |
+
del self.vae.encoder
|
41 |
+
|
42 |
@torch.inference_mode()
|
43 |
def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
|
44 |
return self.vae.encode(x)
|
mmaudio/model/utils/features_utils.py
CHANGED
@@ -41,6 +41,7 @@ class FeaturesUtils(nn.Module):
|
|
41 |
synchformer_ckpt: Optional[str] = None,
|
42 |
enable_conditions: bool = True,
|
43 |
mode=Literal['16k', '44k'],
|
|
|
44 |
):
|
45 |
super().__init__()
|
46 |
|
@@ -64,19 +65,18 @@ class FeaturesUtils(nn.Module):
|
|
64 |
if tod_vae_ckpt is not None:
|
65 |
self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
|
66 |
vocoder_ckpt_path=bigvgan_vocoder_ckpt,
|
67 |
-
mode=mode
|
|
|
68 |
else:
|
69 |
self.tod = None
|
70 |
self.mel_converter = MelConverter()
|
71 |
|
72 |
def compile(self):
|
73 |
if self.clip_model is not None:
|
74 |
-
self.encode_video_with_clip = torch.compile(self.encode_video_with_clip)
|
75 |
self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
|
76 |
self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
|
77 |
if self.synchformer is not None:
|
78 |
self.synchformer = torch.compile(self.synchformer)
|
79 |
-
self.tod.encode = torch.compile(self.tod.encode)
|
80 |
self.decode = torch.compile(self.decode)
|
81 |
self.vocode = torch.compile(self.vocode)
|
82 |
|
@@ -121,9 +121,11 @@ class FeaturesUtils(nn.Module):
|
|
121 |
outputs = []
|
122 |
if batch_size < 0:
|
123 |
batch_size = b
|
124 |
-
|
|
|
125 |
outputs.append(self.synchformer(x[i:i + batch_size]))
|
126 |
-
x = torch.cat(outputs, dim=0)
|
|
|
127 |
return x
|
128 |
|
129 |
@torch.inference_mode()
|
|
|
41 |
synchformer_ckpt: Optional[str] = None,
|
42 |
enable_conditions: bool = True,
|
43 |
mode=Literal['16k', '44k'],
|
44 |
+
need_vae_encoder: bool = True,
|
45 |
):
|
46 |
super().__init__()
|
47 |
|
|
|
65 |
if tod_vae_ckpt is not None:
|
66 |
self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
|
67 |
vocoder_ckpt_path=bigvgan_vocoder_ckpt,
|
68 |
+
mode=mode,
|
69 |
+
need_vae_encoder=need_vae_encoder)
|
70 |
else:
|
71 |
self.tod = None
|
72 |
self.mel_converter = MelConverter()
|
73 |
|
74 |
def compile(self):
|
75 |
if self.clip_model is not None:
|
|
|
76 |
self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
|
77 |
self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
|
78 |
if self.synchformer is not None:
|
79 |
self.synchformer = torch.compile(self.synchformer)
|
|
|
80 |
self.decode = torch.compile(self.decode)
|
81 |
self.vocode = torch.compile(self.vocode)
|
82 |
|
|
|
121 |
outputs = []
|
122 |
if batch_size < 0:
|
123 |
batch_size = b
|
124 |
+
x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
|
125 |
+
for i in range(0, b * num_segments, batch_size):
|
126 |
outputs.append(self.synchformer(x[i:i + batch_size]))
|
127 |
+
x = torch.cat(outputs, dim=0)
|
128 |
+
x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
|
129 |
return x
|
130 |
|
131 |
@torch.inference_mode()
|