waytan22 commited on
Commit
544833b
·
1 Parent(s): 93232f5
Files changed (1) hide show
  1. levo_inference.py +3 -1
levo_inference.py CHANGED
@@ -9,6 +9,7 @@ sys.path.append('.')
9
  import torch
10
 
11
  import json
 
12
  from omegaconf import OmegaConf
13
 
14
  from codeclm.trainer.codec_song_pl import CodecLM_PL
@@ -70,11 +71,12 @@ class LeVoInference(torch.nn.Module):
70
  params = {**self.default_params, **params}
71
  self.model.set_generation_params(**params)
72
 
73
- if prompt_audio_path is not None:
74
  pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)
75
  melody_is_wav = True
76
  elif genre is not None and auto_prompt_path is not None:
77
  auto_prompt = torch.load(auto_prompt_path)
 
78
  if genre == "Auto":
79
  prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
80
  else:
 
9
  import torch
10
 
11
  import json
12
+ import numpy as np
13
  from omegaconf import OmegaConf
14
 
15
  from codeclm.trainer.codec_song_pl import CodecLM_PL
 
71
  params = {**self.default_params, **params}
72
  self.model.set_generation_params(**params)
73
 
74
+ if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
75
  pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)
76
  melody_is_wav = True
77
  elif genre is not None and auto_prompt_path is not None:
78
  auto_prompt = torch.load(auto_prompt_path)
79
+ merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
80
  if genre == "Auto":
81
  prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
82
  else: