Update app.py
app.py CHANGED
@@ -208,34 +208,37 @@ seed_everything(seed, workers=True)
 with open("ThinkSound/configs/model_configs/thinksound.json") as f:
     model_config = json.load(f)

-
-
+model = create_model_from_config(model_config)
+
 ## speed by torch.compile
 if args.compile:
-
+    model = torch.compile(model)

 if args.pretrained_ckpt_path:
-    copy_state_dict(
+    copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion.

 if args.remove_pretransform_weight_norm == "pre_load":
-    remove_weight_norm_from_model(
+    remove_weight_norm_from_model(model.pretransform)


 load_vae_state = load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.')
 # new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
-
+model.pretransform.load_state_dict(load_vae_state)

 # Remove weight_norm from the pretransform if specified
 if args.remove_pretransform_weight_norm == "post_load":
-    remove_weight_norm_from_model(
+    remove_weight_norm_from_model(model.pretransform)
+ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound.ckpt",repo_type="model")
+training_wrapper = create_training_wrapper_from_config(model_config, model)
+# choose map_location according to the device when loading the model weights
+training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])

-
+training_wrapper.to("cuda")

 def get_video_duration(video_path):
     video = VideoFileClip(video_path)
     return video.duration

-a
 @spaces.GPU(duration=60)
 @torch.inference_mode()
 @torch.no_grad()
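The comment added above ("choose map_location according to the device when loading the model weights") describes a device-aware loading pattern, but the added line itself calls torch.load without a map_location argument. Below is a minimal sketch of that pattern, reusing the repo id, filename, and state-dict key from the diff; the variable names and the CPU fallback are illustrative assumptions, not the Space's actual code.

import torch
from huggingface_hub import hf_hub_download

# Fetch the checkpoint referenced in the diff.
ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound",
                            filename="thinksound.ckpt", repo_type="model")

# Pick map_location from the available device so the load also works on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
state = torch.load(ckpt_path, map_location=device)

# The diff then feeds state['state_dict'] into the training wrapper, e.g.:
# training_wrapper.load_state_dict(state['state_dict'])
# training_wrapper.to(device)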
@@ -273,36 +276,36 @@ def synthesize_video_with_audio(video_file, caption, cot):
     sync_seq_len = preprocessed_data['sync_features'].shape[0]
     clip_seq_len = preprocessed_data['metaclip_features'].shape[0]
     latent_seq_len = (int)(194/9*duration_sec)
-
+    training_wrapper.diffusion.model.model.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)

     metadata = [preprocessed_data]

     batch_size = 1
     length = latent_seq_len
     with torch.amp.autocast(device):
-        conditioning =
+        conditioning = training_wrapper.diffusion.conditioner(metadata, training_wrapper.device)

     video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
-    conditioning['metaclip_features'][~video_exist] =
-    conditioning['sync_features'][~video_exist] =
+    conditioning['metaclip_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_clip_feat
+    conditioning['sync_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_sync_feat

     yield "⏳ Inferring…", None

-    cond_inputs =
-    noise = torch.randn([batch_size,
+    cond_inputs = training_wrapper.diffusion.get_conditioning_inputs(conditioning)
+    noise = torch.randn([batch_size, training_wrapper.diffusion.io_channels, length]).to(training_wrapper.device)
     with torch.amp.autocast(device):
-        model =
-        if
+        model = training_wrapper.diffusion.model
+        if training_wrapper.diffusion_objective == "v":
             fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
-        elif
+        elif training_wrapper.diffusion_objective == "rectified_flow":
             import time
             start_time = time.time()
             fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
             end_time = time.time()
             execution_time = end_time - start_time
             print(f"Execution time: {execution_time:.2f} s")
-        if
-            fakes =
+        if training_wrapper.diffusion.pretransform is not None:
+            fakes = training_wrapper.diffusion.pretransform.decode(fakes)

     audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
     with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
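The latent length in the second hunk is derived from the clip duration with the fixed ratio 194/9, i.e. roughly 21.6 latent frames per second, before update_seq_lengths is called with it. A small worked example of that arithmetic (the durations are only illustrative):

# 194 latents per 9 seconds of video, truncated to an integer as in the diff.
def latent_len(duration_sec: float) -> int:
    return int(194 / 9 * duration_sec)

print(latent_len(9.0))   # 194
print(latent_len(4.5))   # 97
print(latent_len(10.0))  # 215 (194/9*10 ≈ 215.56, truncated by int())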
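Both sampler calls pass cfg_scale=5 with batch_cfg=True. As a reminder of what that scale means, here is a minimal sketch of standard classifier-free guidance; the function and tensor names are illustrative and not part of ThinkSound's samplers.

import torch

def apply_cfg(cond_pred: torch.Tensor, uncond_pred: torch.Tensor, cfg_scale: float = 5.0) -> torch.Tensor:
    # Start from the unconditional prediction and push it towards the conditional
    # one by cfg_scale; cfg_scale=1.0 recovers the plain conditional output.
    return uncond_pred + cfg_scale * (cond_pred - uncond_pred)

# With batch_cfg=True a sampler typically concatenates the conditional and
# unconditional inputs into one batch, runs a single forward pass, splits the
# result, and combines the two halves as above.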
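The final context lines peak-normalise the generated waveform and convert it to 16-bit integers before writing a temporary WAV file. A standalone sketch of that conversion, assuming a [channels, samples] float tensor and a 44.1 kHz sample rate (the rate is not visible in this diff), with torchaudio used only for illustration:

import torch
import torchaudio

def to_int16_audio(fakes: torch.Tensor) -> torch.Tensor:
    # Same chain as the diff: peak-normalise to [-1, 1], scale to the
    # int16 range, and move the result back to the CPU.
    return (fakes.to(torch.float32)
                 .div(torch.max(torch.abs(fakes)))
                 .clamp(-1, 1)
                 .mul(32767)
                 .to(torch.int16)
                 .cpu())

waveform = torch.randn(2, 44100)                                # stand-in for a generated clip
torchaudio.save("output.wav", to_int16_audio(waveform), 44100)  # sample rate assumed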