Spaces: Running on Zero
Optimize memory usage
app.py CHANGED
@@ -208,32 +208,19 @@ seed_everything(seed, workers=True)
 with open("ThinkSound/configs/model_configs/thinksound.json") as f:
     model_config = json.load(f)

-
+diffusion_model = create_model_from_config(model_config)
+ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound_light.ckpt",repo_type="model")
+diffusion_model.load_state_dict(torch.load(ckpt_path))
+diffusion_model.to(device)

 ## speed by torch.compile
 if args.compile:
-
-
-if args.pretrained_ckpt_path:
-    copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion.
-
-if args.remove_pretransform_weight_norm == "pre_load":
-    remove_weight_norm_from_model(model.pretransform)
+    diffusion_model = torch.compile(diffusion_model)


 load_vae_state = load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.')
 # new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
-
-
-# Remove weight_norm from the pretransform if specified
-if args.remove_pretransform_weight_norm == "post_load":
-    remove_weight_norm_from_model(model.pretransform)
-ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound.ckpt",repo_type="model")
-training_wrapper = create_training_wrapper_from_config(model_config, model)
-# choose map_location according to the device when loading the model weights
-training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
-
-training_wrapper.to("cuda")

 def get_video_duration(video_path):
     video = VideoFileClip(video_path)
@@ -276,36 +263,36 @@ def synthesize_video_with_audio(video_file, caption, cot):
     sync_seq_len = preprocessed_data['sync_features'].shape[0]
     clip_seq_len = preprocessed_data['metaclip_features'].shape[0]
     latent_seq_len = (int)(194/9*duration_sec)
-
+    diffusion_model.model.model.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)

     metadata = [preprocessed_data]

     batch_size = 1
     length = latent_seq_len
     with torch.amp.autocast(device):
-        conditioning =
+        conditioning = diffusion_model.conditioner(metadata, device)

     video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
-    conditioning['metaclip_features'][~video_exist] =
-    conditioning['sync_features'][~video_exist] =
+    conditioning['metaclip_features'][~video_exist] = diffusion_model.model.model.empty_clip_feat
+    conditioning['sync_features'][~video_exist] = diffusion_model.model.model.empty_sync_feat

     yield "⏳ Inferring…", None

-    cond_inputs =
-    noise = torch.randn([batch_size,
+    cond_inputs = diffusion_model.get_conditioning_inputs(conditioning)
+    noise = torch.randn([batch_size, diffusion_model.io_channels, length]).to(device)
     with torch.amp.autocast(device):
-
-
-
-        elif training_wrapper.diffusion_objective == "rectified_flow":
+        if diffusion_model.diffusion_objective == "v":
+            fakes = sample(diffusion_model.model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
+        elif diffusion_model.diffusion_objective == "rectified_flow":
             import time
             start_time = time.time()
-            fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
+            fakes = sample_discrete_euler(diffusion_model.model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
             end_time = time.time()
             execution_time = end_time - start_time
-            print(f"
-
-
+            print(f"execution_time: {execution_time:.2f} seconds")
+
+        if diffusion_model.pretransform is not None:
+            fakes = diffusion_model.pretransform.decode(fakes)

     audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
     with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio: